diff --git a/.github/workflows/e2e_archgw.yml b/.github/workflows/e2e_archgw.yml index 4509e99e..f13e126f 100644 --- a/.github/workflows/e2e_archgw.yml +++ b/.github/workflows/e2e_archgw.yml @@ -39,6 +39,8 @@ jobs: GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} + AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }} + run: | docker compose up | tee &> archgw.logs & diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 3319f145..1538a964 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -32,6 +32,7 @@ jobs: GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} + AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }} run: | python -mvenv venv source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh diff --git a/arch/supervisord.conf b/arch/supervisord.conf index d4d99494..3d1db8e2 100644 --- a/arch/supervisord.conf +++ b/arch/supervisord.conf @@ -9,7 +9,7 @@ stdout_logfile_maxbytes=0 stderr_logfile_maxbytes=0 [program:envoy] -command=/bin/sh -c "python /app/config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:info --log-format '[%%Y-%%m-%%d %%T.%%e][%%l] %%v' 2>&1 | tee /var/log/envoy.log | while IFS= read -r line; do echo '[envoy_logs] ' \"$line\"; done" +command=/bin/sh -c "python /app/config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug --log-format '[%%Y-%%m-%%d %%T.%%e][%%l] %%v' 2>&1 | tee /var/log/envoy.log | while IFS= read -r line; do echo '[envoy_logs] ' \"$line\"; done" stdout_logfile=/dev/stdout redirect_stderr=true stdout_logfile_maxbytes=0 diff --git a/arch/tools/cli/config_generator.py 
b/arch/tools/cli/config_generator.py index 3e8e02ea..832993ae 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -21,6 +21,7 @@ SUPPORTED_PROVIDERS = [ "moonshotai", "zhipu", "qwen", + "amazon_bedrock", ] @@ -130,7 +131,10 @@ def validate_and_render_schema(): provider = model_name_tokens[0] # Validate azure_openai and ollama provider requires base_url if ( - provider == "azure_openai" or provider == "ollama" or provider == "qwen" + provider == "azure_openai" + or provider == "ollama" + or provider == "qwen" + or provider == "amazon_bedrock" ) and llm_provider.get("base_url") is None: raise Exception( f"Provider '{provider}' requires 'base_url' to be set for model {model_name}" diff --git a/crates/Cargo.lock b/crates/Cargo.lock index db033142..fae883a5 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -91,6 +91,35 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-smithy-eventstream" +version = "0.60.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9656b85088f8d9dc7ad40f9a6c7228e1e8447cdf4b046c87e152e0805dea02fa" +dependencies = [ + "aws-smithy-types", + "bytes", + "crc32fast", +] + +[[package]] +name = "aws-smithy-types" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f5b3a7486f6690ba25952cabf1e7d75e34d69eaff5081904a47bc79074d6457" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -118,6 +147,16 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64-simd" +version = "0.8.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -206,6 +245,16 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + [[package]] name = "cc" version = "1.2.26" @@ -282,6 +331,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -718,6 +776,7 @@ dependencies = [ name = "hermesllm" version = "0.1.0" dependencies = [ + "aws-smithy-eventstream", "serde", "serde_json", "serde_with", @@ -1315,6 +1374,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -1473,6 +1541,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + [[package]] name = "overload" version = "0.1.1" @@ -2762,6 +2836,12 @@ version = "0.9.5" 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "want" version = "0.3.1" diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 53313c36..108a83e2 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -2,9 +2,10 @@ use std::sync::Arc; use std::collections::HashMap; use bytes::Bytes; use common::configuration::{ModelAlias, ModelUsagePreference}; -use common::consts::ARCH_PROVIDER_HINT_HEADER; +use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_IS_STREAMING_HEADER}; use hermesllm::apis::openai::ChatCompletionsRequest; use hermesllm::clients::SupportedAPIs; +use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use hermesllm::{ProviderRequest, ProviderRequestType}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full, StreamBody}; @@ -51,6 +52,7 @@ pub async fn chat( // Model alias resolution: update model field in client_request immediately // This ensures all downstream objects use the resolved model let model_from_request = client_request.model().to_string(); + let is_streaming_request = client_request.is_streaming(); let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() { if let Some(model_alias) = model_aliases.get(&model_from_request) { debug!( @@ -77,9 +79,9 @@ pub async fn chat( // Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original) let chat_completions_request_for_arch_router: ChatCompletionsRequest = - match ProviderRequestType::try_from((client_request, 
&SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) { + match ProviderRequestType::try_from((client_request, &SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) { Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req, - Ok(ProviderRequestType::MessagesRequest(_)) => { + Ok(ProviderRequestType::MessagesRequest(_) | ProviderRequestType::BedrockConverse(_) | ProviderRequestType::BedrockConverseStream(_)) => { // This should not happen after conversion to OpenAI format warn!("Unexpected: got MessagesRequest after converting to OpenAI format"); let err_msg = "Request conversion failed".to_string(); @@ -179,6 +181,11 @@ pub async fn chat( header::HeaderValue::from_str(&model_name).unwrap(), ); + request_headers.insert( + header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER), + header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(), + ); + if let Some(trace_parent) = trace_parent { request_headers.insert( header::HeaderName::from_static("traceparent"), diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 301a8206..3cb98350 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -181,6 +181,8 @@ pub enum LlmProviderType { Zhipu, #[serde(rename = "qwen")] Qwen, + #[serde(rename = "amazon_bedrock")] + AmazonBedrock, } impl Display for LlmProviderType { @@ -200,6 +202,7 @@ impl Display for LlmProviderType { LlmProviderType::Moonshotai => write!(f, "moonshotai"), LlmProviderType::Zhipu => write!(f, "zhipu"), LlmProviderType::Qwen => write!(f, "qwen"), + LlmProviderType::AmazonBedrock => write!(f, "amazon_bedrock"), } } } diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 0eb5a036..6f5a5441 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -11,6 +11,7 @@ pub const MODEL_SERVER_NAME: &str = "model_server"; pub const ARCH_ROUTING_HEADER: &str = 
"x-arch-llm-provider"; pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; +pub const ARCH_IS_STREAMING_HEADER: &str = "x-archgw-streaming-request"; pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const MESSAGES_PATH: &str = "/v1/messages"; pub const HEALTHZ_PATH: &str = "/healthz"; diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index c995e85c..7ba48ea3 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -6,5 +6,6 @@ edition = "2021" [dependencies] serde = {version = "1.0.219", features = ["derive"]} serde_json = "1.0.140" -serde_with = "3.12.0" +serde_with = {version = "3.12.0", features = ["base64"]} thiserror = "2.0.12" +aws-smithy-eventstream = "0.60" diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs new file mode 100644 index 00000000..64bf6e74 --- /dev/null +++ b/crates/hermesllm/src/apis/amazon_bedrock.rs @@ -0,0 +1,934 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use serde_with::skip_serializing_none; + +use thiserror::Error; +use std::collections::HashMap; + +use super::ApiDefinition; +use crate::providers::request::{ProviderRequest, ProviderRequestError}; + +// ============================================================================ +// AMAZON BEDROCK CONVERSE API ENUMERATION +// ============================================================================ + +/// Enum for all supported Amazon Bedrock Converse APIs +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AmazonBedrockApi { + Converse, + ConverseStream, +} + +impl ApiDefinition for AmazonBedrockApi { + fn endpoint(&self) -> &'static str { + match self { + AmazonBedrockApi::Converse => "/model/{modelId}/converse", + AmazonBedrockApi::ConverseStream => "/model/{modelId}/converse-stream", + } + } + + fn from_endpoint(endpoint: &str) -> Option { + if 
endpoint.ends_with("/converse") { + Some(AmazonBedrockApi::Converse) + } else if endpoint.ends_with("/converse-stream") { + Some(AmazonBedrockApi::ConverseStream) + } else { + None + } + } + + fn supports_streaming(&self) -> bool { + match self { + AmazonBedrockApi::Converse => false, + AmazonBedrockApi::ConverseStream => true, + } + } + + fn supports_tools(&self) -> bool { + // Converse API has native tool support + true + } + + fn supports_vision(&self) -> bool { + // Converse API has native vision support + true + } + + fn all_variants() -> Vec { + vec![ + AmazonBedrockApi::Converse, + AmazonBedrockApi::ConverseStream, + ] + } +} + +// ============================================================================ +// CONVERSE REQUEST STRUCTURES +// ============================================================================ + +/// Amazon Bedrock Converse request +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ConverseRequest { + /// The model ID or ARN to invoke + pub model_id: String, + /// The messages to send to the model + pub messages: Option>, + /// System prompts that provide instructions or context + pub system: Option>, + /// Inference configuration + #[serde(rename = "inferenceConfig")] + pub inference_config: Option, + /// Tool configuration for function calling + #[serde(rename = "toolConfig")] + pub tool_config: Option, + /// Guardrail configuration + #[serde(rename = "guardrailConfig")] + pub guardrail_config: Option, + /// Additional model-specific request fields + #[serde(rename = "additionalModelRequestFields")] + pub additional_model_request_fields: Option, + /// Additional model response field paths to return + #[serde(rename = "additionalModelResponseFieldPaths")] + pub additional_model_response_field_paths: Option>, + /// Performance configuration + #[serde(rename = "performanceConfig")] + pub performance_config: Option, + /// Prompt variables for Prompt management + #[serde(rename = "promptVariables")] + 
pub prompt_variables: Option>, + /// Request metadata for filtering logs + #[serde(rename = "requestMetadata")] + pub request_metadata: Option>, + /// Additional custom metadata (for internal use) + pub metadata: Option>, + /// Whether this request should use streaming endpoint (internal field, not serialized) + #[serde(skip)] + pub stream: bool, +} + +impl Default for ConverseRequest { + fn default() -> Self { + Self { + model_id: String::new(), + messages: None, + system: None, + inference_config: None, + tool_config: None, + guardrail_config: None, + additional_model_request_fields: None, + additional_model_response_field_paths: None, + performance_config: None, + prompt_variables: None, + request_metadata: None, + metadata: None, + stream: false, + } + } +} + +/// Amazon Bedrock ConverseStream request (same structure as Converse) +pub type ConverseStreamRequest = ConverseRequest; + +impl ProviderRequest for ConverseRequest { + fn model(&self) -> &str { + &self.model_id + } + + fn set_model(&mut self, model: String) { + self.model_id = model; + } + + fn is_streaming(&self) -> bool { + self.stream + } + + fn extract_messages_text(&self) -> String { + let mut text_parts = Vec::new(); + + // Extract text from messages + if let Some(messages) = &self.messages { + for message in messages { + for content_block in &message.content { + match content_block { + ContentBlock::Text { text } => { + text_parts.push(text.clone()); + } + ContentBlock::GuardContent { guard_content } => { + if let Some(guard_text) = &guard_content.text { + text_parts.push(guard_text.text.clone()); + } + } + _ => {} // Skip non-text content blocks + } + } + } + } + + // Extract text from system prompts + if let Some(system) = &self.system { + for system_block in system { + match system_block { + SystemContentBlock::Text { text } => { + text_parts.push(text.clone()); + } + SystemContentBlock::GuardContent { text: Some(guard_text) } => { + text_parts.push(guard_text.text.clone()); + } + 
SystemContentBlock::GuardContent { text: None } => { + // No text content in this guard content block + } + } + } + } + + text_parts.join(" ") + } + + fn get_recent_user_message(&self) -> Option { + self.messages + .as_ref()? + .iter() + .rev() // Start from the most recent message + .find(|msg| msg.role == ConversationRole::User) + .and_then(|msg| { + // Extract the first text content block from the user message + msg.content.iter().find_map(|content| { + match content { + ContentBlock::Text { text } => Some(text.clone()), + _ => None, + } + }) + }) + } + + fn to_bytes(&self) -> Result, ProviderRequestError> { + serde_json::to_vec(self).map_err(|e| ProviderRequestError { + message: format!("Failed to serialize Bedrock request: {}", e), + source: Some(Box::new(e)), + }) + } + + fn metadata(&self) -> &Option> { + &self.metadata + } + + fn remove_metadata_key(&mut self, key: &str) -> bool { + if let Some(ref mut metadata) = self.metadata { + metadata.remove(key).is_some() + } else { + false + } + } +} + +// ============================================================================ +// CONVERSE RESPONSE STRUCTURES +// ============================================================================ + +/// Amazon Bedrock Converse response +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ConverseResponse { + /// The result from the call to Converse + pub output: ConverseOutput, + /// The reason why the model stopped generating output + #[serde(rename = "stopReason")] + pub stop_reason: StopReason, + /// Token usage information + pub usage: BedrockTokenUsage, + /// Metrics for the call + pub metrics: Option, + /// Additional model response fields + #[serde(rename = "additionalModelResponseFields")] + pub additional_model_response_fields: Option, + /// Performance configuration used + #[serde(rename = "performanceConfig")] + pub performance_config: Option, + /// Trace information for guardrails + pub trace: Option, +} + +/// Amazon 
Bedrock ConverseStream response events +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum ConverseStreamEvent { + MessageStart(MessageStartEvent), + ContentBlockStart(ContentBlockStartEvent), + ContentBlockDelta(ContentBlockDeltaEvent), + ContentBlockStop(ContentBlockStopEvent), + MessageStop(MessageStopEvent), + Metadata(ConverseStreamMetadataEvent), + // Error events + InternalServerException(BedrockException), + ModelStreamErrorException(BedrockException), + ServiceUnavailableException(BedrockException), + ThrottlingException(BedrockException), + ValidationException(BedrockException), +} + +// ============================================================================ +// MESSAGE AND CONTENT STRUCTURES +// ============================================================================ + +/// Message in a conversation +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Message { + /// Role of the message sender (user, assistant) + pub role: ConversationRole, + /// Content blocks in the message + pub content: Vec, +} + +/// Conversation role enumeration +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ConversationRole { + User, + Assistant, +} + +/// Content block in a message +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum ContentBlock { + Text { + text: String + }, + Image { + image: ImageBlock + }, + Document { + document: DocumentBlock + }, + ToolUse { + #[serde(rename = "toolUse")] + tool_use: ToolUseBlock, + }, + ToolResult { + #[serde(rename = "toolResult")] + tool_result: ToolResultBlock, + }, + GuardContent { + #[serde(rename = "guardContent")] + guard_content: GuardContentBlock, + }, +} + +/// Image block structure +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ImageBlock { + pub source: ImageSource, +} + +/// Document block structure +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct DocumentBlock { + pub 
source: DocumentSource, + pub name: Option, +} + +/// Tool use block structure +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ToolUseBlock { + #[serde(rename = "toolUseId")] + pub tool_use_id: String, + pub name: String, + pub input: Value, +} + +/// Tool result block structure +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ToolResultBlock { + #[serde(rename = "toolUseId")] + pub tool_use_id: String, + pub content: Vec, + pub status: Option, +} + +/// Guard content block structure +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct GuardContentBlock { + pub text: Option, +} + +/// System content block for system prompts +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(tag = "type")] +pub enum SystemContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "guardContent")] + GuardContent { + text: Option, + }, +} + +/// Image source for vision capabilities +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(tag = "type")] +pub enum ImageSource { + #[serde(rename = "base64")] + Base64 { + #[serde(rename = "mediaType")] + media_type: String, + data: String, + }, +} + +/// Document source for document processing +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(tag = "type")] +pub enum DocumentSource { + #[serde(rename = "base64")] + Base64 { + #[serde(rename = "mediaType")] + media_type: String, + data: String, + }, +} + +/// Tool result content block +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(tag = "type")] +pub enum ToolResultContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image")] + Image { source: ImageSource }, + #[serde(rename = "json")] + Json { json: Value }, +} + +/// Tool result status +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ToolResultStatus { + Success, + Error, +} + +/// Guard content text with qualifiers +#[derive(Serialize, Deserialize, Debug, 
Clone)] +pub struct GuardContentText { + pub text: String, + pub qualifiers: Option>, +} + +/// Guard content qualifier +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum GuardContentQualifier { + Grounding, + Relevance, + Harmfulness, + Helpfulness, +} + +// ============================================================================ +// INFERENCE AND TOOL CONFIGURATION +// ============================================================================ + +/// Inference configuration for the model +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct InferenceConfiguration { + /// Maximum tokens to generate + #[serde(rename = "maxTokens")] + pub max_tokens: Option, + /// Temperature for randomness (0.0 to 1.0) + pub temperature: Option, + /// Top-p sampling parameter (0.0 to 1.0) + #[serde(rename = "topP")] + pub top_p: Option, + /// Stop sequences to halt generation + #[serde(rename = "stopSequences")] + pub stop_sequences: Option>, +} + +/// Tool configuration for function calling +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ToolConfiguration { + /// Available tools for the model + pub tools: Option>, + /// Tool choice configuration + #[serde(rename = "toolChoice")] + pub tool_choice: Option, +} + +/// Tool definition +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum Tool { + ToolSpec { + #[serde(rename = "toolSpec")] + tool_spec: ToolSpecDefinition, + }, +} + +/// Tool specification definition +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ToolSpecDefinition { + pub name: String, + pub description: Option, + #[serde(rename = "inputSchema")] + pub input_schema: ToolInputSchema, +} + +/// Tool input schema +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ToolInputSchema { + pub json: Value, +} + +/// Tool choice configuration +#[derive(Serialize, Deserialize, Debug, Clone)] 
+#[serde(untagged)] +pub enum ToolChoice { + Auto { + #[serde(rename = "auto")] + auto: AutoChoice, + }, + Any { + #[serde(rename = "any")] + any: AnyChoice, + }, + Tool { + #[serde(rename = "tool")] + tool: ToolChoiceSpec, + }, +} + +/// Auto tool choice (empty object) +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct AutoChoice {} + +/// Any tool choice (empty object) +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct AnyChoice {} + +/// Specific tool choice +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ToolChoiceSpec { + pub name: String, +} + +// ============================================================================ +// GUARDRAIL CONFIGURATION +// ============================================================================ + +/// Guardrail configuration for content filtering +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct GuardrailConfiguration { + /// Guardrail identifier + #[serde(rename = "guardrailIdentifier")] + pub guardrail_identifier: String, + /// Guardrail version + #[serde(rename = "guardrailVersion")] + pub guardrail_version: String, + /// Trace setting + pub trace: Option, +} + +/// Guardrail configuration for streaming (has additional field) +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct GuardrailStreamConfiguration { + /// Guardrail identifier + #[serde(rename = "guardrailIdentifier")] + pub guardrail_identifier: String, + /// Guardrail version + #[serde(rename = "guardrailVersion")] + pub guardrail_version: String, + /// Stream processing mode + #[serde(rename = "streamProcessingMode")] + pub stream_processing_mode: Option, + /// Trace setting + pub trace: Option, +} + +/// Guardrail trace setting +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "UPPERCASE")] +pub enum GuardrailTrace { + Enabled, + Disabled, +} + +// 
============================================================================ +// PERFORMANCE CONFIGURATION +// ============================================================================ + +/// Performance configuration for latency optimization +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct PerformanceConfiguration { + /// Latency optimization setting + pub latency: Option, +} + +/// Performance latency setting +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum PerformanceLatency { + Standard, + Optimized, +} + +// ============================================================================ +// RESPONSE OUTPUT STRUCTURES +// ============================================================================ + +/// Converse output (union type) +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum ConverseOutput { + Message { message: Message }, +} + +/// Stop reason enumeration +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum StopReason { + EndTurn, + ToolUse, + MaxTokens, + StopSequence, + GuardrailIntervened, + ContentFiltered, +} + +/// Token usage information for Bedrock Converse API +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +pub struct BedrockTokenUsage { + /// Input tokens processed + #[serde(rename = "inputTokens")] + pub input_tokens: u32, + /// Output tokens generated + #[serde(rename = "outputTokens")] + pub output_tokens: u32, + /// Total tokens used + #[serde(rename = "totalTokens")] + pub total_tokens: u32, + /// Server tool usage (for function calling) + #[serde(rename = "serverToolUsage")] + pub server_tool_usage: Option, + /// Cache read input tokens + #[serde(rename = "cacheReadInputTokens")] + pub cache_read_input_tokens: Option, + /// Cache write input tokens + #[serde(rename = "cacheWriteInputTokens")] + pub 
cache_write_input_tokens: Option, +} + +/// Converse metrics +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ConverseMetrics { + /// Latency in milliseconds + #[serde(rename = "latencyMs")] + pub latency_ms: u64, +} + +/// Converse trace information +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ConverseTrace { + /// Guardrail trace information + pub guardrail: Option, + /// Prompt router trace information + #[serde(rename = "promptRouter")] + pub prompt_router: Option, +} + +/// Guardrail trace assessment (simplified) +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct GuardrailTraceAssessment { + /// Action reason + #[serde(rename = "actionReason")] + pub action_reason: Option, + /// Model output + #[serde(rename = "modelOutput")] + pub model_output: Option>, + /// Input assessment + #[serde(rename = "inputAssessment")] + pub input_assessment: Option>, + /// Output assessments + #[serde(rename = "outputAssessments")] + pub output_assessments: Option>>, +} + +/// Prompt router trace +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct PromptRouterTrace { + /// Invoked model ID + #[serde(rename = "invokedModelId")] + pub invoked_model_id: String, +} + +// ============================================================================ +// STREAMING EVENT STRUCTURES +// ============================================================================ + +/// Message start event +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessageStartEvent { + /// Role of the message + pub role: ConversationRole, +} + +/// Content block start event +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ContentBlockStartEvent { + /// Content block index + #[serde(rename = "contentBlockIndex")] + pub content_block_index: u32, + /// Start information + pub start: ContentBlockStart, +} + +/// Content block start information +#[derive(Serialize, Deserialize, Debug, Clone)] 
+#[serde(tag = "type")] +pub enum ContentBlockStart { + #[serde(rename = "toolUse")] + ToolUse { + #[serde(rename = "toolUseId")] + tool_use_id: String, + name: String, + }, +} + +/// Content block delta event +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ContentBlockDeltaEvent { + /// Content block index + #[serde(rename = "contentBlockIndex")] + pub content_block_index: u32, + /// Delta information + pub delta: ContentBlockDelta, +} + +/// Content block delta information +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(tag = "type")] +pub enum ContentBlockDelta { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "toolUse")] + ToolUse { input: String }, +} + +/// Content block stop event +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ContentBlockStopEvent { + /// Content block index + #[serde(rename = "contentBlockIndex")] + pub content_block_index: u32, +} + +/// Message stop event +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessageStopEvent { + /// Stop reason + #[serde(rename = "stopReason")] + pub stop_reason: StopReason, + /// Additional model response fields + #[serde(rename = "additionalModelResponseFields")] + pub additional_model_response_fields: Option, +} + +/// Stream metadata event +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ConverseStreamMetadataEvent { + /// Token usage + pub usage: BedrockTokenUsage, + /// Stream metrics + pub metrics: Option, + /// Trace information + pub trace: Option, + /// Performance configuration + #[serde(rename = "performanceConfig")] + pub performance_config: Option, +} + +/// Stream metrics +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ConverseStreamMetrics { + /// Latency in milliseconds + #[serde(rename = "latencyMs")] + pub latency_ms: u64, +} + +/// Stream trace information +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct 
ConverseStreamTrace { + /// Guardrail trace + pub guardrail: Option, + /// Prompt router trace + #[serde(rename = "promptRouter")] + pub prompt_router: Option, +} + +// ============================================================================ +// PROMPT MANAGEMENT +// ============================================================================ + +/// Prompt variable values for Prompt management +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum PromptVariableValues { + Text { text: String }, +} + +// ============================================================================ +// ERROR TYPES +// ============================================================================ + +/// Bedrock exception structure +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct BedrockException { + /// Exception message + pub message: Option, + /// Original status code (for model errors) + #[serde(rename = "originalStatusCode")] + pub original_status_code: Option, + /// Resource name (for model errors) + #[serde(rename = "resourceName")] + pub resource_name: Option, + /// Original message (for stream errors) + #[serde(rename = "originalMessage")] + pub original_message: Option, +} + +/// Bedrock-specific error types +#[derive(Error, Debug)] +pub enum BedrockError { + #[error("Access denied: {message}")] + AccessDenied { message: String }, + + #[error("Internal server error: {message}")] + InternalServer { message: String }, + + #[error("Model error: {message}")] + ModelError { + message: String, + original_status_code: Option, + resource_name: Option, + }, + + #[error("Model not ready: {message}")] + ModelNotReady { message: String }, + + #[error("Model timeout: {message}")] + ModelTimeout { message: String }, + + #[error("Resource not found: {message}")] + ResourceNotFound { message: String }, + + #[error("Service unavailable: {message}")] + ServiceUnavailable { message: String }, + + #[error("Throttling: {message}")] + 
Throttling { message: String }, + + #[error("Validation error: {message}")] + Validation { message: String }, + + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), +} + +// ============================================================================ +// TRAIT IMPLEMENTATIONS +// ============================================================================ + +// Note: Trait implementations will be added later when we implement transformations +// For now, we're focusing on modeling the request/response shapes + +impl crate::providers::response::TokenUsage for BedrockTokenUsage { + fn completion_tokens(&self) -> usize { + self.output_tokens as usize + } + + fn prompt_tokens(&self) -> usize { + self.input_tokens as usize + } + + fn total_tokens(&self) -> usize { + self.total_tokens as usize + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_tool_serialization_format() { + let tool = Tool::ToolSpec { + tool_spec: ToolSpecDefinition { + name: "get_weather".to_string(), + description: Some("Get the current weather for a specified city".to_string()), + input_schema: ToolInputSchema { + json: json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to get weather for" + } + }, + "required": ["city"] + }), + }, + }, + }; + + let serialized = serde_json::to_value(&tool).unwrap(); + println!("Tool serialization: {}", serde_json::to_string_pretty(&serialized).unwrap()); + + // Verify the structure matches Bedrock API expectations + assert!(serialized.get("toolSpec").is_some()); + assert!(serialized.get("type").is_none()); // Should not have a type field + + let tool_spec = serialized.get("toolSpec").unwrap(); + assert_eq!(tool_spec.get("name").unwrap(), "get_weather"); + assert_eq!(tool_spec.get("description").unwrap(), "Get the current weather for a specified city"); + assert!(tool_spec.get("inputSchema").is_some()); + } + + #[test] + fn 
test_tool_choice_serialization_format() { + // Test Auto choice + let auto_choice = ToolChoice::Auto { auto: AutoChoice {} }; + let serialized = serde_json::to_value(&auto_choice).unwrap(); + println!("Auto ToolChoice serialization: {}", serde_json::to_string_pretty(&serialized).unwrap()); + + assert!(serialized.get("auto").is_some()); + assert!(serialized.get("type").is_none()); // Should not have a type field + + // Test Tool choice + let tool_choice = ToolChoice::Tool { + tool: ToolChoiceSpec { + name: "get_weather".to_string() + } + }; + let serialized = serde_json::to_value(&tool_choice).unwrap(); + println!("Tool ToolChoice serialization: {}", serde_json::to_string_pretty(&serialized).unwrap()); + + assert!(serialized.get("tool").is_some()); + assert!(serialized.get("type").is_none()); // Should not have a type field + + let tool_spec = serialized.get("tool").unwrap(); + assert_eq!(tool_spec.get("name").unwrap(), "get_weather"); + } +} diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index abfde5b7..6a06e2f1 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; use super::ApiDefinition; use crate::providers::request::{ProviderRequest, ProviderRequestError}; use crate::providers::response::{ProviderResponse, ProviderStreamResponse}; -use crate::clients::transformer::ExtractText; +use crate::transforms::lib::ExtractText; use crate::{MESSAGES_PATH}; // Enum for all supported Anthropic APIs diff --git a/crates/hermesllm/src/apis/mod.rs b/crates/hermesllm/src/apis/mod.rs index b175988c..f570ac6e 100644 --- a/crates/hermesllm/src/apis/mod.rs +++ b/crates/hermesllm/src/apis/mod.rs @@ -1,8 +1,13 @@ pub mod anthropic; pub mod openai; -pub use anthropic::*; -pub use openai::*; +pub mod amazon_bedrock; +// Explicit exports to avoid naming conflicts +pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent}; 
+pub use openai::{OpenAIApi, ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse}; +pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice}; +pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest}; +pub use amazon_bedrock::{Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice}; pub trait ApiDefinition { /// Returns the endpoint path for this API diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 63b5fc58..3f8fe985 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -8,7 +8,7 @@ use thiserror::Error; use crate::providers::request::{ProviderRequest, ProviderRequestError}; use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage}; use super::ApiDefinition; -use crate::clients::transformer::{ExtractText}; +use crate::transforms::lib::ExtractText; use crate::{CHAT_COMPLETIONS_PATH}; // ============================================================================ diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index e5c01f05..264f2668 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -1,27 +1,5 @@ -//! Supported endpoint registry for LLM APIs -//! -//! This module provides a simple registry to check which API endpoint paths -//! we support across different providers. -//! -//! # Examples -//! -//! ```rust -//! use hermesllm::clients::endpoints::supported_endpoints; -//! -//! // Check if we support an endpoint -//! use hermesllm::clients::endpoints::SupportedAPIs; -//! assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some()); -//! assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some()); -//! assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some()); -//! -//! // Get all supported endpoints -//! 
let endpoints = supported_endpoints(); -//! assert_eq!(endpoints.len(), 2); -//! assert!(endpoints.contains(&"/v1/chat/completions")); -//! assert!(endpoints.contains(&"/v1/messages")); -//! ``` - -use crate::{apis::{AnthropicApi, ApiDefinition, OpenAIApi}, ProviderId}; +use crate::{ProviderId}; +use crate::apis::{OpenAIApi, AnthropicApi, AmazonBedrockApi, ApiDefinition}; use std::fmt; /// Unified enum representing all supported API endpoints across providers @@ -31,6 +9,14 @@ pub enum SupportedAPIs { AnthropicMessagesAPI(AnthropicApi), } +#[derive(Debug, Clone, PartialEq)] +pub enum SupportedUpstreamAPIs { + OpenAIChatCompletions(OpenAIApi), + AnthropicMessagesAPI(AnthropicApi), + AmazonBedrockConverse(AmazonBedrockApi), + AmazonBedrockConverseStream(AmazonBedrockApi), +} + impl fmt::Display for SupportedAPIs { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -62,12 +48,21 @@ impl SupportedAPIs { } } - pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str) -> String { + pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str, is_streaming: bool) -> String { let default_endpoint = "/v1/chat/completions".to_string(); match self { SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => { match provider_id { ProviderId::Anthropic => "/v1/messages".to_string(), + ProviderId::AmazonBedrock => { + if request_path.starts_with("/v1/") && !is_streaming { + format!("/model/{}/converse", model_id) + } else if request_path.starts_with("/v1/") && is_streaming { + format!("/model/{}/converse-stream", model_id) + } else { + default_endpoint + } + } _ => default_endpoint, } } @@ -108,6 +103,16 @@ impl SupportedAPIs { default_endpoint } } + ProviderId::AmazonBedrock => { + if request_path.starts_with("/v1/") && !is_streaming { + format!("/model/{}/converse", model_id) + } else if request_path.starts_with("/v1/") && is_streaming { + 
format!("/model/{}/converse-stream", model_id) + } + else { + default_endpoint + } + } _ => default_endpoint, } } @@ -155,10 +160,10 @@ mod tests { fn test_is_supported_endpoint() { // OpenAI endpoints assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some()); - // Anthropic endpoints assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some()); + // Unsupported endpoints assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some()); assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some()); @@ -168,9 +173,10 @@ mod tests { #[test] fn test_supported_endpoints() { let endpoints = supported_endpoints(); - assert_eq!(endpoints.len(), 2); + assert_eq!(endpoints.len(), 2); // We have 2 APIs defined assert!(endpoints.contains(&"/v1/chat/completions")); assert!(endpoints.contains(&"/v1/messages")); + } #[test] @@ -203,7 +209,6 @@ mod tests { for endpoint in anthropic_endpoints { assert!(endpoints.contains(&endpoint), "Missing Anthropic endpoint: {}", endpoint); } - // Total should match assert_eq!(endpoints.len(), OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len()); } diff --git a/crates/hermesllm/src/clients/transformer.rs b/crates/hermesllm/src/clients/transformer.rs index f6e508d4..92daa298 100644 --- a/crates/hermesllm/src/clients/transformer.rs +++ b/crates/hermesllm/src/clients/transformer.rs @@ -1,1089 +1,7 @@ -//! API request/response transformers between Anthropic and OpenAI APIs -//! -//! This module provides clean, bidirectional conversion between different LLM API formats -//! using Rust's standard `TryFrom` and `Into` traits. The organization follows a logical flow: -//! -//! 1. **Main Request Transformations** - Core TryFrom implementations for requests -//! 2. **Main Response Transformations** - Core TryFrom implementations for responses -//! 3. **Streaming Transformations** - Bidirectional streaming event conversion -//! 4. 
**Standard Rust Trait Implementations** - Into/TryFrom implementations for type conversions -//! 5. **Helper Functions** - Utility functions organized by domain -//! -//! # Examples -//! -//! ```rust -//! use hermesllm::apis::{ -//! MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage, -//! MessagesMessageContent, MessagesSystemPrompt, -//! }; -//! use hermesllm::clients::TransformError; -//! use std::convert::TryInto; -//! -//! // Transform Anthropic to OpenAI -//! let anthropic_req = MessagesRequest { -//! model: "claude-3-sonnet".to_string(), -//! system: None, -//! messages: vec![], -//! max_tokens: 1024, -//! container: None, -//! mcp_servers: None, -//! service_tier: None, -//! thinking: None, -//! temperature: None, -//! top_p: None, -//! top_k: None, -//! stream: None, -//! stop_sequences: None, -//! tools: None, -//! tool_choice: None, -//! metadata: None, -//! }; -//! let openai_req: Result = anthropic_req.try_into(); -//! # Ok::<(), Box>(()) -//! ``` +// Re-export new transformation modules for backward compatibility -use serde_json::Value; -use std::time::{SystemTime, UNIX_EPOCH}; -use crate::apis::*; -use super::TransformError; -// ============================================================================ -// CONSTANTS -// ============================================================================ - -/// Default maximum tokens when converting from OpenAI to Anthropic and no max_tokens is specified -const DEFAULT_MAX_TOKENS: u32 = 4096; - -// ============================================================================ -// UTILITY TRAITS - Shared traits for content manipulation -// ============================================================================ - -/// Trait for extracting text content from various types -pub trait ExtractText { - fn extract_text(&self) -> String; -} - -/// Trait for utility functions on content collections -trait ContentUtils { - fn extract_tool_calls(&self) -> Result>, TransformError>; - fn 
split_for_openai(&self) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError>; -} - -// ============================================================================ -// MAIN REQUEST TRANSFORMATIONS -// ============================================================================ - -type AnthropicMessagesRequest = MessagesRequest; - - -impl TryFrom for ChatCompletionsRequest { - type Error = TransformError; - - fn try_from(req: AnthropicMessagesRequest) -> Result { - let mut openai_messages: Vec = Vec::new(); - - // Convert system prompt to system message if present - if let Some(system) = req.system { - openai_messages.push(system.into()); - } - - // Convert messages - for message in req.messages { - let converted_messages: Vec = message.try_into()?; - openai_messages.extend(converted_messages); - } - - // Convert tools and tool choice - let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools)); - let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice); - - let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest { - model: req.model, - messages: openai_messages, - temperature: req.temperature, - top_p: req.top_p, - max_completion_tokens: Some(req.max_tokens), - stream: req.stream, - stop: req.stop_sequences, - tools: openai_tools, - tool_choice: openai_tool_choice, - parallel_tool_calls, - ..Default::default() - }; - _chat_completions_req.suppress_max_tokens_if_o3(); - _chat_completions_req.fix_temperature_if_gpt5(); - Ok(_chat_completions_req) - } -} - -impl TryFrom for AnthropicMessagesRequest { - type Error = TransformError; - - fn try_from(req: ChatCompletionsRequest) -> Result { - let mut system_prompt = None; - let mut messages = Vec::new(); - - for message in req.messages { - match message.role { - Role::System => { - system_prompt = Some(message.into()); - } - _ => { - let anthropic_message: MessagesMessage = message.try_into()?; - messages.push(anthropic_message); - } 
- } - } - - // Convert tools and tool choice - let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools)); - let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls); - - Ok(AnthropicMessagesRequest { - model: req.model, - system: system_prompt, - messages, - max_tokens: req.max_completion_tokens - .or(req.max_tokens) - .unwrap_or(DEFAULT_MAX_TOKENS), - container: None, - mcp_servers: None, - service_tier: None, - thinking: None, - temperature: req.temperature, - top_p: req.top_p, - top_k: None, // OpenAI doesn't have top_k - stream: req.stream, - stop_sequences: req.stop, - tools: anthropic_tools, - tool_choice: anthropic_tool_choice, - metadata: None, - }) - } -} - -// ============================================================================ -// MAIN RESPONSE TRANSFORMATIONS -// ============================================================================ - -impl TryFrom for ChatCompletionsResponse { - type Error = TransformError; - - fn try_from(resp: MessagesResponse) -> Result { - let content = convert_anthropic_content_to_openai(&resp.content)?; - let finish_reason: FinishReason = resp.stop_reason.into(); - let tool_calls = resp.content.extract_tool_calls()?; - - // Convert MessageContent to String for response - let content_string = match content { - MessageContent::Text(text) => Some(text), - MessageContent::Parts(parts) => { - let text = parts.extract_text(); - if text.is_empty() { None } else { Some(text) } - } - }; - - let message = ResponseMessage { - role: Role::Assistant, - content: content_string, - refusal: None, - annotations: None, - audio: None, - function_call: None, - tool_calls, - }; - - let choice = Choice { - index: 0, - message, - finish_reason: Some(finish_reason), - logprobs: None, - }; - - let usage = Usage { - prompt_tokens: resp.usage.input_tokens, - completion_tokens: resp.usage.output_tokens, - total_tokens: resp.usage.input_tokens + resp.usage.output_tokens, - 
prompt_tokens_details: None, - completion_tokens_details: None, - }; - - Ok(ChatCompletionsResponse { - id: resp.id, - object: Some("chat.completion".to_string()), - created: current_timestamp(), - model: resp.model, - choices: vec![choice], - usage, - system_fingerprint: None, - service_tier: None, - }) - } -} - -impl TryFrom for MessagesResponse { - type Error = TransformError; - - fn try_from(resp: ChatCompletionsResponse) -> Result { - let choice = resp.choices.into_iter().next() - .ok_or_else(|| TransformError::MissingField("choices".to_string()))?; - - let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?; - let stop_reason = choice.finish_reason - .map(|fr| fr.into()) - .unwrap_or(MessagesStopReason::EndTurn); - - let usage = MessagesUsage { - input_tokens: resp.usage.prompt_tokens, - output_tokens: resp.usage.completion_tokens, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }; - - Ok(MessagesResponse { - id: resp.id, - obj_type: "message".to_string(), - role: MessagesRole::Assistant, - content, - model: resp.model, - stop_reason, - stop_sequence: None, - usage, - container: None, - }) - } -} - -// ============================================================================ -// STREAMING TRANSFORMATIONS -// ============================================================================ - -impl TryFrom for ChatCompletionsStreamResponse { - type Error = TransformError; - - fn try_from(event: MessagesStreamEvent) -> Result { - match event { - MessagesStreamEvent::MessageStart { message } => { - Ok(create_openai_chunk( - &message.id, - &message.model, - MessageDelta { - role: Some(Role::Assistant), - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )) - } - - MessagesStreamEvent::ContentBlockStart { content_block, .. } => { - convert_content_block_start(content_block) - } - - MessagesStreamEvent::ContentBlockDelta { delta, .. 
} => { - convert_content_delta(delta) - } - - MessagesStreamEvent::ContentBlockStop { .. } => { - Ok(create_empty_openai_chunk()) - } - - MessagesStreamEvent::MessageDelta { delta, usage } => { - let finish_reason: Option = Some(delta.stop_reason.into()); - let openai_usage: Option = Some(usage.into()); - - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason, - openai_usage, - )) - } - - MessagesStreamEvent::MessageStop => { - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - Some(FinishReason::Stop), - None, - )) - } - - MessagesStreamEvent::Ping => { - Ok(ChatCompletionsStreamResponse { - id: "stream".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: current_timestamp(), - model: "unknown".to_string(), - choices: vec![], - usage: None, - system_fingerprint: None, - service_tier: None, - }) - } - } - } -} - -impl TryFrom for MessagesStreamEvent { - type Error = TransformError; - - fn try_from(resp: ChatCompletionsStreamResponse) -> Result { - if resp.choices.is_empty() { - return Ok(MessagesStreamEvent::Ping); - } - - let choice = &resp.choices[0]; - - // Handle final chunk with usage - let has_usage = resp.usage.is_some(); - if let Some(usage) = resp.usage { - if let Some(finish_reason) = &choice.finish_reason { - let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); - return Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: usage.into(), - }); - } - } - - // Handle role start - if let Some(Role::Assistant) = choice.delta.role { - return Ok(MessagesStreamEvent::MessageStart { - message: MessagesStreamMessage { - id: resp.id, - obj_type: "message".to_string(), - role: 
MessagesRole::Assistant, - content: vec![], - model: resp.model, - stop_reason: None, - stop_sequence: None, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }, - }); - } - - // Handle content delta - if let Some(content) = &choice.delta.content { - if !content.is_empty() { - return Ok(MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::TextDelta { - text: content.clone(), - }, - }); - } - } - - // Handle tool calls - if let Some(tool_calls) = &choice.delta.tool_calls { - return convert_tool_call_deltas(tool_calls.clone()); - } - - // Handle finish reason - generate MessageDelta only (MessageStop comes later) - if let Some(finish_reason) = &choice.finish_reason { - // If we have usage data, it was already handled above - // If not, we need to generate MessageDelta with default usage - if !has_usage { - let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); - return Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }); - } - // If usage was already handled above, we don't need to do anything more here - // MessageStop will be handled when [DONE] is encountered - } - - // Default to ping for unhandled cases - Ok(MessagesStreamEvent::Ping) - } -} - -// ============================================================================ -// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for conversions -// ============================================================================ - -// System Prompt Conversions -impl Into for MessagesSystemPrompt { - fn into(self) -> Message { - let system_content = match self { - MessagesSystemPrompt::Single(text) => MessageContent::Text(text), - 
MessagesSystemPrompt::Blocks(blocks) => { - MessageContent::Text(blocks.extract_text()) - } - }; - - Message { - role: Role::System, - content: system_content, - name: None, - tool_calls: None, - tool_call_id: None, - } - } -} - -impl Into for Message { - fn into(self) -> MessagesSystemPrompt { - let system_text = match self.content { - MessageContent::Text(text) => text, - MessageContent::Parts(parts) => parts.extract_text() - }; - MessagesSystemPrompt::Single(system_text) - } -} - -// Message Conversions -impl TryFrom for Vec { - type Error = TransformError; - - fn try_from(message: MessagesMessage) -> Result { - let mut result = Vec::new(); - - match message.content { - MessagesMessageContent::Single(text) => { - result.push(Message { - role: message.role.into(), - content: MessageContent::Text(text), - name: None, - tool_calls: None, - tool_call_id: None, - }); - } - MessagesMessageContent::Blocks(blocks) => { - let (content_parts, tool_calls, tool_results) = blocks.split_for_openai()?; - // Add tool result messages - for (tool_use_id, result_text, _is_error) in tool_results { - result.push(Message { - role: Role::Tool, - content: MessageContent::Text(result_text), - name: None, - tool_calls: None, - tool_call_id: Some(tool_use_id), - }); - } - - // Only create main message if there's actual content or tool calls - // Skip creating empty content messages (e.g., when message only contains tool_result blocks) - if !content_parts.is_empty() || !tool_calls.is_empty() { - let content = build_openai_content(content_parts, &tool_calls); - let main_message = Message { - role: message.role.into(), - content, - name: None, - tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) }, - tool_call_id: None, - }; - result.push(main_message); - } - } - } - - Ok(result) - } -} - -impl TryFrom for MessagesMessage { - type Error = TransformError; - - fn try_from(message: Message) -> Result { - let role = match message.role { - Role::User => MessagesRole::User, - 
Role::Assistant => MessagesRole::Assistant, - Role::Tool => { - // Tool messages become user messages with tool results - let tool_call_id = message.tool_call_id - .ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?; - - return Ok(MessagesMessage { - role: MessagesRole::User, - content: MessagesMessageContent::Blocks(vec![ - MessagesContentBlock::ToolResult { - tool_use_id: tool_call_id, - is_error: None, - content: ToolResultContent::Blocks(vec![MessagesContentBlock::Text { - text: message.content.extract_text(), - cache_control: None, - }]), - cache_control: None, - }, - ]), - }); - } - Role::System => { - return Err(TransformError::UnsupportedConversion("System messages should be handled separately".to_string())); - } - }; - - let content_blocks = convert_openai_message_to_anthropic_content(&message)?; - let content = build_anthropic_content(content_blocks); - - Ok(MessagesMessage { role, content }) - } -} - -// Role Conversions -impl Into for MessagesRole { - fn into(self) -> Role { - match self { - MessagesRole::User => Role::User, - MessagesRole::Assistant => Role::Assistant, - } - } -} - -// Content Utilities -impl ContentUtils for Vec { - fn extract_tool_calls(&self) -> Result>, TransformError> { - let mut tool_calls = Vec::new(); - - for block in self { - match block { - MessagesContentBlock::ToolUse { id, name, input, .. 
} | - MessagesContentBlock::ServerToolUse { id, name, input } | - MessagesContentBlock::McpToolUse { id, name, input } => { - let arguments = serde_json::to_string(&input)?; - tool_calls.push(ToolCall { - id: id.clone(), - call_type: "function".to_string(), - function: FunctionCall { name: name.clone(), arguments }, - }); - } - _ => continue, - } - } - - Ok(if tool_calls.is_empty() { None } else { Some(tool_calls) }) - } - - fn split_for_openai(&self) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError> { - let mut content_parts = Vec::new(); - let mut tool_calls = Vec::new(); - let mut tool_results = Vec::new(); - - for block in self { - match block { - MessagesContentBlock::Text { text, .. } => { - content_parts.push(ContentPart::Text { text: text.clone() }); - } - MessagesContentBlock::Image { source } => { - let url = convert_image_source_to_url(source); - content_parts.push(ContentPart::ImageUrl { - image_url: ImageUrl { - url, - detail: Some("auto".to_string()), - }, - }); - } - MessagesContentBlock::ToolUse { id, name, input, .. } | - MessagesContentBlock::ServerToolUse { id, name, input } | - MessagesContentBlock::McpToolUse { id, name, input } => { - let arguments = serde_json::to_string(&input)?; - tool_calls.push(ToolCall { - id: id.clone(), - call_type: "function".to_string(), - function: FunctionCall { name: name.clone(), arguments }, - }); - } - MessagesContentBlock::ToolResult { tool_use_id, content, is_error, .. 
} => { - let result_text = content.extract_text(); - tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false))); - } - MessagesContentBlock::WebSearchToolResult { tool_use_id, content, is_error } | - MessagesContentBlock::CodeExecutionToolResult { tool_use_id, content, is_error } | - MessagesContentBlock::McpToolResult { tool_use_id, content, is_error } => { - let result_text = content.extract_text(); - tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false))); - } - _ => { - // Skip unsupported content types - continue; - } - } - } - - Ok((content_parts, tool_calls, tool_results)) - } -} - -// Stop Reason Conversions -impl Into for MessagesStopReason { - fn into(self) -> FinishReason { - match self { - MessagesStopReason::EndTurn => FinishReason::Stop, - MessagesStopReason::MaxTokens => FinishReason::Length, - MessagesStopReason::StopSequence => FinishReason::Stop, - MessagesStopReason::ToolUse => FinishReason::ToolCalls, - MessagesStopReason::PauseTurn => FinishReason::Stop, - MessagesStopReason::Refusal => FinishReason::ContentFilter, - } - } -} - -impl Into for FinishReason { - fn into(self) -> MessagesStopReason { - match self { - FinishReason::Stop => MessagesStopReason::EndTurn, - FinishReason::Length => MessagesStopReason::MaxTokens, - FinishReason::ToolCalls => MessagesStopReason::ToolUse, - FinishReason::ContentFilter => MessagesStopReason::Refusal, - FinishReason::FunctionCall => MessagesStopReason::ToolUse, - } - } -} - -// Usage Conversions -impl Into for MessagesUsage { - fn into(self) -> Usage { - Usage { - prompt_tokens: self.input_tokens, - completion_tokens: self.output_tokens, - total_tokens: self.input_tokens + self.output_tokens, - prompt_tokens_details: None, - completion_tokens_details: None, - } - } -} - -impl Into for Usage { - fn into(self) -> MessagesUsage { - MessagesUsage { - input_tokens: self.prompt_tokens, - output_tokens: self.completion_tokens, - cache_creation_input_tokens: None, - 
cache_read_input_tokens: None, - } - } -} - -// ============================================================================ -// HELPER FUNCTIONS - Organized by domain -// ============================================================================ - -/// Helper to create a current unix timestamp -fn current_timestamp() -> u64 { - SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() -} - -/// Helper to create OpenAI streaming chunk -fn create_openai_chunk( - id: &str, - model: &str, - delta: MessageDelta, - finish_reason: Option, - usage: Option -) -> ChatCompletionsStreamResponse { - ChatCompletionsStreamResponse { - id: id.to_string(), - object: Some("chat.completion.chunk".to_string()), - created: current_timestamp(), - model: model.to_string(), - choices: vec![StreamChoice { - index: 0, - delta, - finish_reason, - logprobs: None, - }], - usage, - system_fingerprint: None, - service_tier: None, - } -} - -/// Helper to create empty OpenAI streaming chunk -fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { - create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - ) -} - -/// Convert Anthropic tools to OpenAI format -fn convert_anthropic_tools(tools: Vec) -> Vec { - tools.into_iter() - .map(|tool| Tool { - tool_type: "function".to_string(), - function: Function { - name: tool.name, - description: tool.description, - parameters: tool.input_schema, - strict: None, - }, - }) - .collect() -} - -/// Convert OpenAI tools to Anthropic format -fn convert_openai_tools(tools: Vec) -> Vec { - tools.into_iter() - .map(|tool| MessagesTool { - name: tool.function.name, - description: tool.function.description, - input_schema: tool.function.parameters, - }) - .collect() -} - -/// Convert Anthropic tool choice to OpenAI format -fn convert_anthropic_tool_choice(tool_choice: Option) -> (Option, Option) { - match tool_choice { - 
Some(choice) => { - let openai_choice = match choice.kind { - MessagesToolChoiceType::Auto => ToolChoice::Type(ToolChoiceType::Auto), - MessagesToolChoiceType::Any => ToolChoice::Type(ToolChoiceType::Required), - MessagesToolChoiceType::None => ToolChoice::Type(ToolChoiceType::None), - MessagesToolChoiceType::Tool => { - if let Some(name) = choice.name { - ToolChoice::Function { - choice_type: "function".to_string(), - function: FunctionChoice { name }, - } - } else { - ToolChoice::Type(ToolChoiceType::Auto) - } - } - }; - let parallel = choice.disable_parallel_tool_use.map(|disable| !disable); - (Some(openai_choice), parallel) - } - None => (None, None) - } -} - -/// Convert OpenAI tool choice to Anthropic format -fn convert_openai_tool_choice( - tool_choice: Option, - parallel_tool_calls: Option -) -> Option { - tool_choice.map(|choice| { - match choice { - ToolChoice::Type(tool_type) => match tool_type { - ToolChoiceType::Auto => MessagesToolChoice { - kind: MessagesToolChoiceType::Auto, - name: None, - disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), - }, - ToolChoiceType::Required => MessagesToolChoice { - kind: MessagesToolChoiceType::Any, - name: None, - disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), - }, - ToolChoiceType::None => MessagesToolChoice { - kind: MessagesToolChoiceType::None, - name: None, - disable_parallel_tool_use: None, - }, - }, - ToolChoice::Function { function, .. 
} => MessagesToolChoice { - kind: MessagesToolChoiceType::Tool, - name: Some(function.name), - disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), - }, - } - }) -} - -/// Build OpenAI message content from parts and tool calls -fn build_openai_content(content_parts: Vec, tool_calls: &[ToolCall]) -> MessageContent { - if content_parts.len() == 1 && tool_calls.is_empty() { - match &content_parts[0] { - ContentPart::Text { text } => MessageContent::Text(text.clone()), - _ => MessageContent::Parts(content_parts), - } - } else if content_parts.is_empty() { - MessageContent::Text("".to_string()) - } else { - MessageContent::Parts(content_parts) - } -} - -/// Build Anthropic message content from content blocks -fn build_anthropic_content(content_blocks: Vec) -> MessagesMessageContent { - if content_blocks.len() == 1 { - match &content_blocks[0] { - MessagesContentBlock::Text { text, .. } => MessagesMessageContent::Single(text.clone()), - _ => MessagesMessageContent::Blocks(content_blocks), - } - } else if content_blocks.is_empty() { - MessagesMessageContent::Single("".to_string()) - } else { - MessagesMessageContent::Blocks(content_blocks) - } -} - -/// Convert Anthropic content blocks to OpenAI message content -fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Result { - let mut text_parts = Vec::new(); - - for block in content { - match block { - MessagesContentBlock::Text { text, .. } => { - text_parts.push(text.clone()); - } - MessagesContentBlock::Thinking { thinking, .. 
} => { - text_parts.push(format!("thinking: {}", thinking)); - } - _ => { - // Skip other content types for basic text conversion - continue; - } - } - } - - Ok(MessageContent::Text(text_parts.join("\n"))) -} - -/// Convert OpenAI message to Anthropic content blocks -fn convert_openai_message_to_anthropic_content(message: &Message) -> Result, TransformError> { - let mut blocks = Vec::new(); - - // Handle regular content - match &message.content { - MessageContent::Text(text) => { - if !text.is_empty() { - blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None }); - } - } - MessageContent::Parts(parts) => { - for part in parts { - match part { - ContentPart::Text { text } => { - blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None }); - } - ContentPart::ImageUrl { image_url } => { - let source = convert_image_url_to_source(image_url); - blocks.push(MessagesContentBlock::Image { source }); - } - } - } - } - } - - // Handle tool calls - if let Some(tool_calls) = &message.tool_calls { - for tool_call in tool_calls { - let input: Value = serde_json::from_str(&tool_call.function.arguments)?; - blocks.push(MessagesContentBlock::ToolUse { - id: tool_call.id.clone(), - name: tool_call.function.name.clone(), - input, - cache_control: None, - }); - } - } - - Ok(blocks) -} - -/// Convert image source to URL -fn convert_image_source_to_url(source: &MessagesImageSource) -> String { - match source { - MessagesImageSource::Base64 { media_type, data } => { - format!("data:{};base64,{}", media_type, data) - } - MessagesImageSource::Url { url } => url.clone(), - } -} - -/// Convert image URL to Anthropic image source -fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource { - if image_url.url.starts_with("data:") { - // Parse data URL - let parts: Vec<&str> = image_url.url.splitn(2, ',').collect(); - if parts.len() == 2 { - let header = parts[0]; - let data = parts[1]; - let media_type = header - 
.strip_prefix("data:") - .and_then(|s| s.split(';').next()) - .unwrap_or("image/jpeg") - .to_string(); - - MessagesImageSource::Base64 { - media_type, - data: data.to_string(), - } - } else { - MessagesImageSource::Url { url: image_url.url.clone() } - } - } else { - MessagesImageSource::Url { url: image_url.url.clone() } - } -} - -/// Convert content block start to OpenAI chunk -fn convert_content_block_start(content_block: MessagesContentBlock) -> Result { - match content_block { - MessagesContentBlock::Text { .. } => { - // No immediate output for text block start - Ok(create_empty_openai_chunk()) - } - MessagesContentBlock::ToolUse { id, name, .. } | - MessagesContentBlock::ServerToolUse { id, name, .. } | - MessagesContentBlock::McpToolUse { id, name, .. } => { - // Tool use start → OpenAI chunk with tool_calls - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: Some(id), - call_type: Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some(name), - arguments: Some("".to_string()), - }), - }]), - }, - None, - None, - )) - } - _ => Err(TransformError::UnsupportedContent("Unsupported content block type in stream start".to_string())), - } -} - -/// Convert content delta to OpenAI chunk -fn convert_content_delta(delta: MessagesContentDelta) -> Result { - match delta { - MessagesContentDelta::TextDelta { text } => { - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(text), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )) - } - MessagesContentDelta::ThinkingDelta { thinking } => { - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(format!("thinking: {}", thinking)), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )) - } - 
MessagesContentDelta::InputJsonDelta { partial_json } => { - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: None, - call_type: None, - function: Some(FunctionCallDelta { - name: None, - arguments: Some(partial_json), - }), - }]), - }, - None, - None, - )) - } - } -} - -/// Convert tool call deltas to Anthropic stream events -fn convert_tool_call_deltas(tool_calls: Vec) -> Result { - for tool_call in tool_calls { - if let Some(id) = &tool_call.id { - // Tool call start - if let Some(function) = &tool_call.function { - if let Some(name) = &function.name { - return Ok(MessagesStreamEvent::ContentBlockStart { - index: tool_call.index, - content_block: MessagesContentBlock::ToolUse { - id: id.clone(), - name: name.clone(), - input: Value::Object(serde_json::Map::new()), - cache_control: None, - }, - }); - } - } - } else if let Some(function) = &tool_call.function { - if let Some(arguments) = &function.arguments { - // Tool arguments delta - return Ok(MessagesStreamEvent::ContentBlockDelta { - index: tool_call.index, - delta: MessagesContentDelta::InputJsonDelta { - partial_json: arguments.clone(), - }, - }); - } - } - } - - // Fallback to ping if no valid tool call found - Ok(MessagesStreamEvent::Ping) -} +//KEEPING THE TESTS TO MAKE SURE ALL THE REFACTORING DIDN'T BREAK ANYTHING // ============================================================================ // TESTS @@ -1091,8 +9,11 @@ fn convert_tool_call_deltas(tool_calls: Vec) -> Result for ProviderId { @@ -39,7 +40,8 @@ impl From<&str> for ProviderId { "ollama" => ProviderId::Ollama, "moonshotai" => ProviderId::Moonshotai, "zhipu" => ProviderId::Zhipu, - "qwen" => ProviderId::Qwen, // alias for Zhipu + "qwen" => ProviderId::Qwen, // alias for Qwen + "amazon_bedrock" => ProviderId::AmazonBedrock, _ => panic!("Unknown provider: {}", value), } } @@ -47,11 +49,11 @@ impl 
From<&str> for ProviderId { impl ProviderId { /// Given a client API, return the compatible upstream API for this provider - pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs { + pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs, is_streaming: bool) -> SupportedUpstreamAPIs { match (self, client_api) { // Claude/Anthropic providers natively support Anthropic APIs - (ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), - (ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + (ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), + (ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), // OpenAI-compatible providers only support OpenAI chat completions (ProviderId::OpenAI @@ -68,7 +70,7 @@ impl ProviderId { | ProviderId::Moonshotai | ProviderId::Zhipu | ProviderId::Qwen, - SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), (ProviderId::OpenAI | ProviderId::Groq @@ -84,7 +86,23 @@ impl ProviderId { | ProviderId::Moonshotai | ProviderId::Zhipu | ProviderId::Qwen, - SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + SupportedAPIs::OpenAIChatCompletions(_)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + + // Amazon Bedrock natively supports Bedrock APIs + (ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => { + if is_streaming { + 
SupportedUpstreamAPIs::AmazonBedrockConverseStream(AmazonBedrockApi::ConverseStream) + } else { + SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) + } + }, + (ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => { + if is_streaming { + SupportedUpstreamAPIs::AmazonBedrockConverseStream(AmazonBedrockApi::ConverseStream) + } else { + SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) + } + }, } } } @@ -107,6 +125,7 @@ impl Display for ProviderId { ProviderId::Moonshotai => write!(f, "moonshotai"), ProviderId::Zhipu => write!(f, "zhipu"), ProviderId::Qwen => write!(f, "qwen"), + ProviderId::AmazonBedrock => write!(f, "amazon_bedrock"), } } } diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 3603edf2..87c4aee8 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -1,6 +1,9 @@ use crate::apis::openai::ChatCompletionsRequest; use crate::apis::anthropic::MessagesRequest; + +use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest}; use crate::clients::endpoints::SupportedAPIs; +use crate::clients::endpoints::SupportedUpstreamAPIs; use serde_json::Value; use std::error::Error; @@ -10,6 +13,8 @@ use std::collections::HashMap; pub enum ProviderRequestType { ChatCompletionsRequest(ChatCompletionsRequest), MessagesRequest(MessagesRequest), + BedrockConverse(ConverseRequest), + BedrockConverseStream(ConverseStreamRequest), //add more request types here } pub trait ProviderRequest: Send + Sync { @@ -42,6 +47,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.model(), Self::MessagesRequest(r) => r.model(), + Self::BedrockConverse(r) => r.model(), + Self::BedrockConverseStream(r) => r.model(), } } @@ -49,6 +56,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.set_model(model), 
Self::MessagesRequest(r) => r.set_model(model), + Self::BedrockConverse(r) => r.set_model(model), + Self::BedrockConverseStream(r) => r.set_model(model), } } @@ -56,6 +65,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.is_streaming(), Self::MessagesRequest(r) => r.is_streaming(), + Self::BedrockConverse(_) => false, + Self::BedrockConverseStream(_) => true, } } @@ -63,6 +74,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.extract_messages_text(), Self::MessagesRequest(r) => r.extract_messages_text(), + Self::BedrockConverse(r) => r.extract_messages_text(), + Self::BedrockConverseStream(r) => r.extract_messages_text(), } } @@ -70,6 +83,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.get_recent_user_message(), Self::MessagesRequest(r) => r.get_recent_user_message(), + Self::BedrockConverse(r) => r.get_recent_user_message(), + Self::BedrockConverseStream(r) => r.get_recent_user_message(), } } @@ -77,6 +92,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.to_bytes(), Self::MessagesRequest(r) => r.to_bytes(), + Self::BedrockConverse(r) => r.to_bytes(), + Self::BedrockConverseStream(r) => r.to_bytes(), } } @@ -84,6 +101,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.metadata(), Self::MessagesRequest(r) => r.metadata(), + Self::BedrockConverse(r) => r.metadata(), + Self::BedrockConverseStream(r) => r.metadata(), } } @@ -91,6 +110,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.remove_metadata_key(key), Self::MessagesRequest(r) => r.remove_metadata_key(key), + Self::BedrockConverse(r) => r.remove_metadata_key(key), + Self::BedrockConverseStream(r) => r.remove_metadata_key(key), } } } @@ -117,21 +138,21 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for 
ProviderRequestType { } /// Conversion from one ProviderRequestType to a different ProviderRequestType (SupportedAPIs) -impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { +impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType { type Error = ProviderRequestError; - fn try_from((request, upstream_api): (ProviderRequestType, &SupportedAPIs)) -> Result { - match (request, upstream_api) { + fn try_from((client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs)) -> Result { + match (client_request, upstream_api) { // Same API - no conversion needed, just clone the reference - (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => { + (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) } - (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { + (ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { Ok(ProviderRequestType::MessagesRequest(messages_req)) } // Cross-API conversion - cloning is necessary for transformation - (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { + (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { let messages_req = MessagesRequest::try_from(chat_req) .map_err(|e| ProviderRequestError { message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e), @@ -140,7 +161,7 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { Ok(ProviderRequestType::MessagesRequest(messages_req)) } - (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => { + (ProviderRequestType::MessagesRequest(messages_req), 
SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { let chat_req = ChatCompletionsRequest::try_from(messages_req) .map_err(|e| ProviderRequestError { message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e), @@ -148,6 +169,41 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { })?; Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) } + + // Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock + (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => { + let bedrock_req = ConverseRequest::try_from(chat_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::BedrockConverse(bedrock_req)) + } + + (ProviderRequestType::ChatCompletionsRequest(_), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { + todo!("ChatCompletionsRequest to Amazon Bedrock Stream conversion not implemented yet") + } + (ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => { + let bedrock_req = ConverseRequest::try_from(messages_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert MessagesRequest to Amazon Bedrock request: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::BedrockConverse(bedrock_req)) + } + (ProviderRequestType::MessagesRequest(_), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { + todo!("MessagesRequest to Amazon Bedrock Stream conversion not implemented yet") + } + + // Amazon Bedrock to other APIs conversions + (ProviderRequestType::BedrockConverse(_), _) => { + todo!("Amazon Bedrock to ChatCompletionsRequest conversion not implemented yet") + } + + (ProviderRequestType::BedrockConverseStream(_), _) => { + todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet") + } + 
} } } @@ -182,7 +238,7 @@ mod tests { use crate::apis::openai::OpenAIApi::ChatCompletions; use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest; use crate::apis::openai::{ChatCompletionsRequest}; - use crate::clients::transformer::ExtractText; + use crate::transforms::lib::ExtractText; use serde_json::json; #[test] diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index 6bc4e25f..831636a8 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -1,3 +1,4 @@ +use crate::clients::endpoints::SupportedUpstreamAPIs; use crate::providers::id::ProviderId; use serde::{Serialize, Deserialize}; use std::error::Error; @@ -10,6 +11,7 @@ use crate::apis::openai::ChatCompletionsStreamResponse; use crate::apis::anthropic::MessagesStreamEvent; use crate::clients::endpoints::SupportedAPIs; use crate::apis::anthropic::MessagesResponse; +use crate::apis::amazon_bedrock::ConverseResponse; /// Trait for token usage information pub trait TokenUsage { @@ -30,6 +32,7 @@ pub enum ProviderResponseType { pub enum ProviderStreamResponseType { ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), MessagesStreamEvent(MessagesStreamEvent), + } pub trait ProviderResponse: Send + Sync { @@ -213,19 +216,19 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { type Error = std::io::Error; fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result { - let upstream_api = provider_id.compatible_api_for_client(client_api); + let upstream_api = provider_id.compatible_api_for_client(client_api, false); match (&upstream_api, client_api) { - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) .map_err(|e| 
std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderResponseType::ChatCompletionsResponse(resp)) } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { let resp: MessagesResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderResponseType::MessagesResponse(resp)) } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -234,7 +237,7 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) } - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -243,32 +246,43 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; Ok(ProviderResponseType::MessagesResponse(messages_resp)) } + // Amazon Bedrock transformations + (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to OpenAI 
ChatCompletions format using the transformer + let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) + } + (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to Anthropic Messages format using the transformer + let messages_resp: MessagesResponse = bedrock_resp.try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + Ok(ProviderResponseType::MessagesResponse(messages_resp)) + } + (SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), _) => { + todo!("Amazon Bedrock streaming response transformation not implemented yet") + } } } } // Stream response transformation logic for client API compatibility -impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponseType { +impl TryFrom<(&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { type Error = Box; - fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs)) -> Result { + fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result { match (upstream_api, client_api) { - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { let resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp)) } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let resp: 
crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; - Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) - } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; - - // Transform to OpenAI ChatCompletions stream format using the transformer - let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = anthropic_resp.try_into()?; - Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(chat_resp)) - } - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion if bytes == b"[DONE]" { return Ok(ProviderStreamResponseType::MessagesStreamEvent( @@ -277,20 +291,45 @@ impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponse } let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; - - // Transform to Anthropic Messages stream format using the transformer let messages_resp: crate::apis::anthropic::MessagesStreamEvent = openai_resp.try_into()?; Ok(ProviderStreamResponseType::MessagesStreamEvent(messages_resp)) } + (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) + } + (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; + + // Transform to OpenAI ChatCompletions stream format using the transformer + let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = anthropic_resp.try_into()?; + 
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(chat_resp)) + } + + (SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + todo!("Amazon Bedrock to OpenAI streaming transformation not implemented yet") + } + + (SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + todo!("Anthropic to Amazon Bedrock streaming transformation not implemented yet") + } + + (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + todo!("Amazon Bedrock streaming response transformation not implemented yet") + } + + (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + todo!("Amazon Bedrock streaming response transformation not implemented yet") + } } } } // TryFrom implementation to convert raw bytes to SseEvent with parsed provider response -impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent { +impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent { type Error = Box; - fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedAPIs)) -> Result { + fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result { // Create a new transformed event based on the original let mut transformed_event = sse_event; @@ -305,13 +344,31 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent { } match (client_api, upstream_api) { - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + (SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { // No transformation needed } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + (SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { // No transformation needed } - 
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + + (SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => { + // This should never get called since we are in the streaming path + + } + + (SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { + // TODO: Implement OpenAI to Amazon Bedrock SSE transformation + } + + (SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { + // TODO: Implement Anthropic to Amazon Bedrock SSE transformation + } + + (SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => { + // TODO: Implement Anthropic to Amazon Bedrock SSE transformation + } + + (SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { if let Some(provider_response) = &transformed_event.provider_stream_response { if let Some(event_type) = provider_response.event_type() { // This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s) @@ -351,7 +408,7 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent { // This handles cases where the transformation might not produce a valid event type } } - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + (SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { if transformed_event.is_event_only() && transformed_event.event.is_some() { transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI } @@ -401,13 +458,15 @@ where } // TryFrom implementation to parse bytes into SseStreamIter +// Handles both text-based SSE and binary AWS Event Stream formats impl TryFrom<&[u8]> for SseStreamIter> { type Error = Box; fn try_from(bytes: &[u8]) -> Result { - let s = std::str::from_utf8(bytes)?; - let lines: Vec = 
s.lines().map(|line| line.to_string()).collect(); - Ok(SseStreamIter::new(lines.into_iter())) + // Parse as text-based SSE format + let s = std::str::from_utf8(bytes)?; + let lines: Vec = s.lines().map(|line| line.to_string()).collect(); + Ok(SseStreamIter::new(lines.into_iter())) } } @@ -806,7 +865,7 @@ mod tests { // Test that [DONE] marker is properly converted to MessageStop in the transformation layer let done_bytes = b"[DONE]"; let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); - let upstream_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions); let result = ProviderStreamResponseType::try_from((done_bytes.as_slice(), &client_api, &upstream_api)); assert!(result.is_ok()); diff --git a/crates/hermesllm/src/transforms/lib.rs b/crates/hermesllm/src/transforms/lib.rs new file mode 100644 index 00000000..9f2c49ed --- /dev/null +++ b/crates/hermesllm/src/transforms/lib.rs @@ -0,0 +1,172 @@ +use serde_json::Value; +use crate::apis::anthropic::{MessagesContentBlock,MessagesImageSource}; +use crate::apis::openai::{ContentPart, FunctionCall, ImageUrl, Message, MessageContent, ToolCall}; +use crate::clients::TransformError; +use std::time::{SystemTime, UNIX_EPOCH}; + +pub trait ExtractText { + fn extract_text(&self) -> String; +} + +/// Trait for utility functions on content collections +pub trait ContentUtils { + fn extract_tool_calls(&self) -> Result>, TransformError>; + fn split_for_openai(&self) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError>; +} + +/// Helper to create a current unix timestamp +pub fn current_timestamp() -> u64 { + SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() +} + +// Content Utilities +impl ContentUtils for Vec { + fn extract_tool_calls(&self) -> Result>, TransformError> { + let mut tool_calls = Vec::new(); + + for block in self { 
+ match block { + MessagesContentBlock::ToolUse { id, name, input, .. } | + MessagesContentBlock::ServerToolUse { id, name, input } | + MessagesContentBlock::McpToolUse { id, name, input } => { + let arguments = serde_json::to_string(&input)?; + tool_calls.push(ToolCall { + id: id.clone(), + call_type: "function".to_string(), + function: FunctionCall { name: name.clone(), arguments }, + }); + } + _ => continue, + } + } + + Ok(if tool_calls.is_empty() { None } else { Some(tool_calls) }) + } + + fn split_for_openai(&self) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError> { + let mut content_parts = Vec::new(); + let mut tool_calls = Vec::new(); + let mut tool_results = Vec::new(); + + for block in self { + match block { + MessagesContentBlock::Text { text, .. } => { + content_parts.push(ContentPart::Text { text: text.clone() }); + } + MessagesContentBlock::Image { source } => { + let url = convert_image_source_to_url(source); + content_parts.push(ContentPart::ImageUrl { + image_url: ImageUrl { + url, + detail: Some("auto".to_string()), + }, + }); + } + MessagesContentBlock::ToolUse { id, name, input, .. } | + MessagesContentBlock::ServerToolUse { id, name, input } | + MessagesContentBlock::McpToolUse { id, name, input } => { + let arguments = serde_json::to_string(&input)?; + tool_calls.push(ToolCall { + id: id.clone(), + call_type: "function".to_string(), + function: FunctionCall { name: name.clone(), arguments }, + }); + } + MessagesContentBlock::ToolResult { tool_use_id, content, is_error, .. 
} => { + let result_text = content.extract_text(); + tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false))); + } + MessagesContentBlock::WebSearchToolResult { tool_use_id, content, is_error } | + MessagesContentBlock::CodeExecutionToolResult { tool_use_id, content, is_error } | + MessagesContentBlock::McpToolResult { tool_use_id, content, is_error } => { + let result_text = content.extract_text(); + tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false))); + } + _ => { + // Skip unsupported content types + continue; + } + } + } + + Ok((content_parts, tool_calls, tool_results)) + } +} + +/// Convert image source to URL +pub fn convert_image_source_to_url(source: &MessagesImageSource) -> String { + match source { + MessagesImageSource::Base64 { media_type, data } => { + format!("data:{};base64,{}", media_type, data) + } + MessagesImageSource::Url { url } => url.clone(), + } +} + +/// Convert image URL to Anthropic image source +fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource { + if image_url.url.starts_with("data:") { + // Parse data URL + let parts: Vec<&str> = image_url.url.splitn(2, ',').collect(); + if parts.len() == 2 { + let header = parts[0]; + let data = parts[1]; + let media_type = header + .strip_prefix("data:") + .and_then(|s| s.split(';').next()) + .unwrap_or("image/jpeg") + .to_string(); + + MessagesImageSource::Base64 { + media_type, + data: data.to_string(), + } + } else { + MessagesImageSource::Url { url: image_url.url.clone() } + } + } else { + MessagesImageSource::Url { url: image_url.url.clone() } + } +} + +/// Convert OpenAI message to Anthropic content blocks +pub fn convert_openai_message_to_anthropic_content(message: &Message) -> Result, TransformError> { + let mut blocks = Vec::new(); + + // Handle regular content + match &message.content { + MessageContent::Text(text) => { + if !text.is_empty() { + blocks.push(MessagesContentBlock::Text { text: text.clone(), 
cache_control: None }); + } + } + MessageContent::Parts(parts) => { + for part in parts { + match part { + ContentPart::Text { text } => { + blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None }); + } + ContentPart::ImageUrl { image_url } => { + let source = convert_image_url_to_source(image_url); + blocks.push(MessagesContentBlock::Image { source }); + } + } + } + } + } + + // Handle tool calls + if let Some(tool_calls) = &message.tool_calls { + for tool_call in tool_calls { + let input: Value = serde_json::from_str(&tool_call.function.arguments)?; + blocks.push(MessagesContentBlock::ToolUse { + id: tool_call.id.clone(), + name: tool_call.function.name.clone(), + input, + cache_control: None, + }); + } + } + + Ok(blocks) +} diff --git a/crates/hermesllm/src/transforms/mod.rs b/crates/hermesllm/src/transforms/mod.rs new file mode 100644 index 00000000..256c319b --- /dev/null +++ b/crates/hermesllm/src/transforms/mod.rs @@ -0,0 +1,25 @@ +//! API transformation modules +//! +//! This module provides organized transformations between the two main LLM API formats: +//! - `/v1/chat/completions` (OpenAI format) +//! - `/v1/messages` (Anthropic format) +//! +//! Provider-specific transformations (Bedrock, Groq, etc.) are handled internally +//! by the gateway, but the external API surface remains these two standard formats. +//! The transformations are split into logical modules for maintainability. 
+ +pub mod request; +pub mod response; +pub mod lib; + +// Re-export commonly used items for convenience +pub use request::*; +pub use response::*; +pub use lib::*; + +// ============================================================================ +// CONSTANTS +// ============================================================================ + +/// Default maximum tokens when converting from OpenAI to Anthropic and no max_tokens is specified +pub const DEFAULT_MAX_TOKENS: u32 = 4096; diff --git a/crates/hermesllm/src/transforms/request/from_anthropic.rs b/crates/hermesllm/src/transforms/request/from_anthropic.rs new file mode 100644 index 00000000..38d39979 --- /dev/null +++ b/crates/hermesllm/src/transforms/request/from_anthropic.rs @@ -0,0 +1,665 @@ +use crate::transforms::lib::*; +use crate::clients::TransformError; +use crate::apis::anthropic::{MessagesMessage, MessagesRequest, MessagesMessageContent, MessagesRole, MessagesStopReason, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage, MessagesSystemPrompt, ToolResultContent}; +use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Function, FunctionChoice,FinishReason, Usage, ContentPart}; +use crate::apis::amazon_bedrock::{ + ConverseRequest, SystemContentBlock, InferenceConfiguration, ToolConfiguration, + Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolInputSchema, ToolSpecDefinition, + AutoChoice, AnyChoice, ToolChoiceSpec, + Message as BedrockMessage, ConversationRole, ContentBlock, + ToolUseBlock, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ImageBlock, ImageSource +}; + +type AnthropicMessagesRequest = MessagesRequest; + +// Conversion from Anthropic MessagesRequest to OpenAI ChatCompletionsRequest +impl TryFrom for ChatCompletionsRequest { + type Error = TransformError; + + fn try_from(req: AnthropicMessagesRequest) -> Result { + let mut openai_messages: Vec = Vec::new(); + + // Convert 
system prompt to system message if present + if let Some(system) = req.system { + openai_messages.push(system.into()); + } + + // Convert messages + for message in req.messages { + let converted_messages: Vec = message.try_into()?; + openai_messages.extend(converted_messages); + } + + // Convert tools and tool choice + let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools)); + let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice); + + let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest { + model: req.model, + messages: openai_messages, + temperature: req.temperature, + top_p: req.top_p, + max_completion_tokens: Some(req.max_tokens), + stream: req.stream, + stop: req.stop_sequences, + tools: openai_tools, + tool_choice: openai_tool_choice, + parallel_tool_calls, + ..Default::default() + }; + _chat_completions_req.suppress_max_tokens_if_o3(); + _chat_completions_req.fix_temperature_if_gpt5(); + Ok(_chat_completions_req) + } +} + +// Conversion from Anthropic MessagesRequest to Amazon Bedrock ConverseRequest +impl TryFrom for ConverseRequest { + type Error = TransformError; + + fn try_from(req: AnthropicMessagesRequest) -> Result { + // Convert system prompt to SystemContentBlock if present + let system: Option> = req.system.map(|system_prompt| { + let text = match system_prompt { + MessagesSystemPrompt::Single(text) => text, + MessagesSystemPrompt::Blocks(blocks) => blocks.extract_text(), + }; + vec![SystemContentBlock::Text { text }] + }); + + // Convert messages to Bedrock format + let messages = if req.messages.is_empty() { + None + } else { + let mut bedrock_messages = Vec::new(); + for anthropic_message in req.messages { + let bedrock_message: BedrockMessage = anthropic_message.try_into()?; + bedrock_messages.push(bedrock_message); + } + Some(bedrock_messages) + }; + + // Build inference configuration + // Anthropic always requires max_tokens, so we should always include 
inferenceConfig + let inference_config = Some(InferenceConfiguration { + max_tokens: Some(req.max_tokens), + temperature: req.temperature, + top_p: req.top_p, + stop_sequences: req.stop_sequences, + }); + + // Convert tools and tool choice to ToolConfiguration + let tool_config = if req.tools.is_some() || req.tool_choice.is_some() { + let tools = req.tools.map(|anthropic_tools| { + anthropic_tools.into_iter() + .map(|tool| BedrockTool::ToolSpec { + tool_spec: ToolSpecDefinition { + name: tool.name, + description: tool.description, + input_schema: ToolInputSchema { + json: tool.input_schema, + }, + }, + }) + .collect() + }); + + let tool_choice = req.tool_choice.map(|choice| { + match choice.kind { + MessagesToolChoiceType::Auto => BedrockToolChoice::Auto { auto: AutoChoice {} }, + MessagesToolChoiceType::Any => BedrockToolChoice::Any { any: AnyChoice {} }, + MessagesToolChoiceType::None => BedrockToolChoice::Auto { auto: AutoChoice {} }, // Bedrock doesn't have explicit "none" + MessagesToolChoiceType::Tool => { + if let Some(name) = choice.name { + BedrockToolChoice::Tool { + tool: ToolChoiceSpec { name } + } + } else { + BedrockToolChoice::Auto { auto: AutoChoice {} } + } + } + } + }); + + Some(ToolConfiguration { + tools, + tool_choice, + }) + } else { + None + }; + + Ok(ConverseRequest { + model_id: req.model, + messages, + system, + inference_config, + tool_config, + stream: req.stream.unwrap_or(false), + guardrail_config: None, + additional_model_request_fields: None, + additional_model_response_field_paths: None, + performance_config: None, + prompt_variables: None, + request_metadata: None, + metadata: None, + }) + } +} + + +// Message Conversions +impl TryFrom for Vec { + type Error = TransformError; + + fn try_from(message: MessagesMessage) -> Result { + let mut result = Vec::new(); + + match message.content { + MessagesMessageContent::Single(text) => { + result.push(Message { + role: message.role.into(), + content: MessageContent::Text(text), + name: 
None, + tool_calls: None, + tool_call_id: None, + }); + } + MessagesMessageContent::Blocks(blocks) => { + let (content_parts, tool_calls, tool_results) = blocks.split_for_openai()?; + // Add tool result messages + for (tool_use_id, result_text, _is_error) in tool_results { + result.push(Message { + role: Role::Tool, + content: MessageContent::Text(result_text), + name: None, + tool_calls: None, + tool_call_id: Some(tool_use_id), + }); + } + + // Only create main message if there's actual content or tool calls + // Skip creating empty content messages (e.g., when message only contains tool_result blocks) + if !content_parts.is_empty() || !tool_calls.is_empty() { + let content = build_openai_content(content_parts, &tool_calls); + let main_message = Message { + role: message.role.into(), + content, + name: None, + tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) }, + tool_call_id: None, + }; + result.push(main_message); + } + } + } + + Ok(result) + } +} + + +// Role Conversions +impl Into for MessagesRole { + fn into(self) -> Role { + match self { + MessagesRole::User => Role::User, + MessagesRole::Assistant => Role::Assistant, + } + } +} + +impl Into for FinishReason { + fn into(self) -> MessagesStopReason { + match self { + FinishReason::Stop => MessagesStopReason::EndTurn, + FinishReason::Length => MessagesStopReason::MaxTokens, + FinishReason::ToolCalls => MessagesStopReason::ToolUse, + FinishReason::ContentFilter => MessagesStopReason::Refusal, + FinishReason::FunctionCall => MessagesStopReason::ToolUse, + } + } +} + +impl Into for Usage { + fn into(self) -> MessagesUsage { + MessagesUsage { + input_tokens: self.prompt_tokens, + output_tokens: self.completion_tokens, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + } + } +} + +// System Prompt Conversions +impl Into for MessagesSystemPrompt { + fn into(self) -> Message { + let system_content = match self { + MessagesSystemPrompt::Single(text) => 
MessageContent::Text(text), + MessagesSystemPrompt::Blocks(blocks) => { + MessageContent::Text(blocks.extract_text()) + } + }; + + Message { + role: Role::System, + content: system_content, + name: None, + tool_calls: None, + tool_call_id: None, + } + } +} + +//Utility Functions +/// Convert Anthropic tools to OpenAI format +fn convert_anthropic_tools(tools: Vec) -> Vec { + tools.into_iter() + .map(|tool| Tool { + tool_type: "function".to_string(), + function: Function { + name: tool.name, + description: tool.description, + parameters: tool.input_schema, + strict: None, + }, + }) + .collect() +} + +/// Convert Anthropic tool choice to OpenAI format +fn convert_anthropic_tool_choice(tool_choice: Option) -> (Option, Option) { + match tool_choice { + Some(choice) => { + let openai_choice = match choice.kind { + MessagesToolChoiceType::Auto => ToolChoice::Type(ToolChoiceType::Auto), + MessagesToolChoiceType::Any => ToolChoice::Type(ToolChoiceType::Required), + MessagesToolChoiceType::None => ToolChoice::Type(ToolChoiceType::None), + MessagesToolChoiceType::Tool => { + if let Some(name) = choice.name { + ToolChoice::Function { + choice_type: "function".to_string(), + function: FunctionChoice { name }, + } + } else { + ToolChoice::Type(ToolChoiceType::Auto) + } + } + }; + let parallel = choice.disable_parallel_tool_use.map(|disable| !disable); + (Some(openai_choice), parallel) + } + None => (None, None) + } +} + +/// Build OpenAI message content from parts and tool calls +fn build_openai_content(content_parts: Vec, tool_calls: &[ToolCall]) -> MessageContent { + if content_parts.len() == 1 && tool_calls.is_empty() { + match &content_parts[0] { + ContentPart::Text { text } => MessageContent::Text(text.clone()), + _ => MessageContent::Parts(content_parts), + } + } else if content_parts.is_empty() { + MessageContent::Text("".to_string()) + } else { + MessageContent::Parts(content_parts) + } +} + +impl TryFrom for BedrockMessage { + type Error = TransformError; + + fn 
try_from(message: MessagesMessage) -> Result { + let role = match message.role { + MessagesRole::User => ConversationRole::User, + MessagesRole::Assistant => ConversationRole::Assistant, + }; + + let mut content_blocks = Vec::new(); + + // Convert content blocks + match message.content { + MessagesMessageContent::Single(text) => { + if !text.is_empty() { + content_blocks.push(ContentBlock::Text { text }); + } + } + MessagesMessageContent::Blocks(blocks) => { + for block in blocks { + match block { + crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => { + if !text.is_empty() { + content_blocks.push(ContentBlock::Text { text }); + } + } + crate::apis::anthropic::MessagesContentBlock::ToolUse { id, name, input, .. } => { + content_blocks.push(ContentBlock::ToolUse { + tool_use: ToolUseBlock { + tool_use_id: id, + name, + input, + }, + }); + } + crate::apis::anthropic::MessagesContentBlock::ToolResult { tool_use_id, is_error, content, .. } => { + // Convert Anthropic ToolResultContent to Bedrock ToolResultContentBlock + let tool_result_content = match content { + ToolResultContent::Text(text) => { + vec![ToolResultContentBlock::Text { text }] + } + ToolResultContent::Blocks(blocks) => { + let mut result_blocks = Vec::new(); + for result_block in blocks { + match result_block { + crate::apis::anthropic::MessagesContentBlock::Text { text, .. 
} => { + result_blocks.push(ToolResultContentBlock::Text { text }); + } + // For now, skip other content types in tool results + _ => {} + } + } + result_blocks + } + }; + + // Ensure we have at least one content block + let final_content = if tool_result_content.is_empty() { + vec![ToolResultContentBlock::Text { text: " ".to_string() }] + } else { + tool_result_content + }; + + let status = if is_error.unwrap_or(false) { + Some(ToolResultStatus::Error) + } else { + Some(ToolResultStatus::Success) + }; + + content_blocks.push(ContentBlock::ToolResult { + tool_result: ToolResultBlock { + tool_use_id, + content: final_content, + status, + }, + }); + } + crate::apis::anthropic::MessagesContentBlock::Image { source } => { + // Convert Anthropic image to Bedrock image format + match source { + crate::apis::anthropic::MessagesImageSource::Base64 { media_type, data } => { + content_blocks.push(ContentBlock::Image { + image: ImageBlock { + source: ImageSource::Base64 { + media_type, + data, + }, + }, + }); + } + crate::apis::anthropic::MessagesImageSource::Url { .. } => { + // Bedrock doesn't support URL-based images, skip for now + // Could potentially download and convert to base64, but not implemented + } + } + } + // Skip other content types for now (Thinking, Document, etc.) 
+ _ => {} + } + } + } + } + + // Ensure we have at least one content block + if content_blocks.is_empty() { + content_blocks.push(ContentBlock::Text { text: " ".to_string() }); + } + + Ok(BedrockMessage { + role, + content: content_blocks, + }) + } +} + + + +#[cfg(test)] +mod tests { + use super::*; + use crate::apis::anthropic::{MessagesRequest, MessagesMessage, MessagesMessageContent, MessagesRole, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesSystemPrompt}; + use crate::apis::amazon_bedrock::{ConverseRequest, SystemContentBlock, ToolChoice as BedrockToolChoice, ConversationRole, ContentBlock}; + use serde_json::json; + + #[test] + fn test_anthropic_to_bedrock_basic_request() { + let anthropic_request = MessagesRequest { + model: "claude-3-5-sonnet-20241022".to_string(), + messages: vec![ + MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Hello, how are you?".to_string()), + } + ], + max_tokens: 1000, + container: None, + mcp_servers: None, + system: Some(MessagesSystemPrompt::Single("You are a helpful assistant.".to_string())), + metadata: None, + service_tier: None, + thinking: None, + temperature: Some(0.7), + top_p: Some(0.9), + top_k: None, + stream: Some(false), + stop_sequences: Some(vec!["STOP".to_string()]), + tools: None, + tool_choice: None, + }; + + let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap(); + + assert_eq!(bedrock_request.model_id, "claude-3-5-sonnet-20241022"); + assert!(bedrock_request.system.is_some()); + assert_eq!(bedrock_request.system.as_ref().unwrap().len(), 1); + assert!(bedrock_request.messages.is_some()); + let messages = bedrock_request.messages.as_ref().unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, ConversationRole::User); + + if let ContentBlock::Text { text } = &messages[0].content[0] { + assert_eq!(text, "Hello, how are you?"); + } else { + panic!("Expected text content block"); + } + + let inference_config = 
bedrock_request.inference_config.as_ref().unwrap(); + assert_eq!(inference_config.temperature, Some(0.7)); + assert_eq!(inference_config.top_p, Some(0.9)); + assert_eq!(inference_config.max_tokens, Some(1000)); + assert_eq!(inference_config.stop_sequences, Some(vec!["STOP".to_string()])); + } + + #[test] + fn test_anthropic_to_bedrock_with_tools() { + let anthropic_request = MessagesRequest { + model: "claude-3-5-sonnet-20241022".to_string(), + messages: vec![ + MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("What's the weather like?".to_string()), + } + ], + max_tokens: 1000, + container: None, + mcp_servers: None, + system: None, + metadata: None, + service_tier: None, + thinking: None, + temperature: None, + top_p: None, + top_k: None, + stream: None, + stop_sequences: None, + tools: Some(vec![ + MessagesTool { + name: "get_weather".to_string(), + description: Some("Get current weather information".to_string()), + input_schema: json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name" + } + }, + "required": ["location"] + }), + } + ]), + tool_choice: Some(MessagesToolChoice { + kind: MessagesToolChoiceType::Tool, + name: Some("get_weather".to_string()), + disable_parallel_tool_use: None, + }), + }; + + let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap(); + + assert_eq!(bedrock_request.model_id, "claude-3-5-sonnet-20241022"); + assert!(bedrock_request.tool_config.is_some()); + + let tool_config = bedrock_request.tool_config.as_ref().unwrap(); + assert!(tool_config.tools.is_some()); + let tools = tool_config.tools.as_ref().unwrap(); + assert_eq!(tools.len(), 1); + let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0]; + assert_eq!(tool_spec.name, "get_weather"); + assert_eq!(tool_spec.description, Some("Get current weather information".to_string())); + + if let Some(BedrockToolChoice::Tool { tool }) = 
&tool_config.tool_choice { + assert_eq!(tool.name, "get_weather"); + } else { + panic!("Expected specific tool choice"); + } + } + + #[test] + fn test_anthropic_to_bedrock_auto_tool_choice() { + let anthropic_request = MessagesRequest { + model: "claude-3-5-sonnet-20241022".to_string(), + messages: vec![ + MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Help me with something".to_string()), + } + ], + max_tokens: 500, + container: None, + mcp_servers: None, + system: None, + metadata: None, + service_tier: None, + thinking: None, + temperature: None, + top_p: None, + top_k: None, + stream: None, + stop_sequences: None, + tools: Some(vec![ + MessagesTool { + name: "help_tool".to_string(), + description: Some("A helpful tool".to_string()), + input_schema: json!({ + "type": "object", + "properties": {} + }), + } + ]), + tool_choice: Some(MessagesToolChoice { + kind: MessagesToolChoiceType::Auto, + name: None, + disable_parallel_tool_use: None, + }), + }; + + let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap(); + + assert!(bedrock_request.tool_config.is_some()); + let tool_config = bedrock_request.tool_config.as_ref().unwrap(); + assert!(matches!(tool_config.tool_choice, Some(BedrockToolChoice::Auto { .. }))); + } + + #[test] + fn test_anthropic_to_bedrock_multi_message_conversation() { + let anthropic_request = MessagesRequest { + model: "claude-3-5-sonnet-20241022".to_string(), + messages: vec![ + MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Hello".to_string()), + }, + MessagesMessage { + role: MessagesRole::Assistant, + content: MessagesMessageContent::Single("Hi there! 
How can I help you?".to_string()), + }, + MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("What's 2+2?".to_string()), + } + ], + max_tokens: 100, + container: None, + mcp_servers: None, + system: Some(MessagesSystemPrompt::Single("Be concise".to_string())), + metadata: None, + service_tier: None, + thinking: None, + temperature: Some(0.5), + top_p: None, + top_k: None, + stream: None, + stop_sequences: None, + tools: None, + tool_choice: None, + }; + + let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap(); + + assert!(bedrock_request.messages.is_some()); + let messages = bedrock_request.messages.as_ref().unwrap(); + assert_eq!(messages.len(), 3); + assert_eq!(messages[0].role, ConversationRole::User); + assert_eq!(messages[1].role, ConversationRole::Assistant); + assert_eq!(messages[2].role, ConversationRole::User); + + // Check system prompt + assert!(bedrock_request.system.is_some()); + if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] { + assert_eq!(text, "Be concise"); + } else { + panic!("Expected system text block"); + } + } + + #[test] + fn test_anthropic_message_to_bedrock_conversion() { + let anthropic_message = MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Test message".to_string()), + }; + + let bedrock_message: BedrockMessage = anthropic_message.try_into().unwrap(); + + assert_eq!(bedrock_message.role, ConversationRole::User); + assert_eq!(bedrock_message.content.len(), 1); + + if let ContentBlock::Text { text } = &bedrock_message.content[0] { + assert_eq!(text, "Test message"); + } else { + panic!("Expected text content block"); + } + } +} diff --git a/crates/hermesllm/src/transforms/request/from_openai.rs b/crates/hermesllm/src/transforms/request/from_openai.rs new file mode 100644 index 00000000..a40950a7 --- /dev/null +++ b/crates/hermesllm/src/transforms/request/from_openai.rs @@ -0,0 +1,732 @@ +use 
crate::transforms::lib::ExtractText; +use crate::transforms::lib::*; +use crate::clients::TransformError; +use crate::transforms::*; +use crate::apis::anthropic::{MessagesSystemPrompt, MessagesMessage,MessagesRequest, MessagesMessageContent, MessagesContentBlock, MessagesRole, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, ToolResultContent}; +use crate::apis::openai::{ChatCompletionsRequest, Message, Role, Tool, ToolChoice, ToolChoiceType, MessageContent}; +use crate::apis::amazon_bedrock::{ + ConverseRequest, SystemContentBlock, InferenceConfiguration, ToolConfiguration, + Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolInputSchema, ToolSpecDefinition, + AutoChoice, AnyChoice, ToolChoiceSpec, + Message as BedrockMessage, ConversationRole, ContentBlock +}; + +type AnthropicMessagesRequest = MessagesRequest; + +// ============================================================================ +// MAIN REQUEST TRANSFORMATIONS +// ============================================================================ + +impl Into for Message { + fn into(self) -> MessagesSystemPrompt { + let system_text = match self.content { + MessageContent::Text(text) => text, + MessageContent::Parts(parts) => parts.extract_text() + }; + MessagesSystemPrompt::Single(system_text) + } +} + +impl TryFrom for MessagesMessage { + type Error = TransformError; + + fn try_from(message: Message) -> Result { + let role = match message.role { + Role::User => MessagesRole::User, + Role::Assistant => MessagesRole::Assistant, + Role::Tool => { + // Tool messages become user messages with tool results + let tool_call_id = message.tool_call_id + .ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?; + + return Ok(MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Blocks(vec![ + MessagesContentBlock::ToolResult { + tool_use_id: tool_call_id, + is_error: None, + content: 
ToolResultContent::Blocks(vec![MessagesContentBlock::Text { + text: message.content.extract_text(), + cache_control: None, + }]), + cache_control: None, + }, + ]), + }); + } + Role::System => { + return Err(TransformError::UnsupportedConversion("System messages should be handled separately".to_string())); + } + }; + + let content_blocks = convert_openai_message_to_anthropic_content(&message)?; + let content = build_anthropic_content(content_blocks); + + Ok(MessagesMessage { role, content }) + } +} + +impl TryFrom for BedrockMessage { + type Error = TransformError; + + fn try_from(message: Message) -> Result { + let role = match message.role { + Role::User => ConversationRole::User, + Role::Assistant => ConversationRole::Assistant, + Role::Tool => ConversationRole::User, // Tool results become user messages in Bedrock + Role::System => { + return Err(TransformError::UnsupportedConversion("System messages should be handled separately in Bedrock".to_string())); + } + }; + + let mut content_blocks = Vec::new(); + + // Handle different message types + match message.role { + Role::User => { + // Convert user message content to content blocks + match message.content { + MessageContent::Text(text) => { + if !text.is_empty() { + content_blocks.push(ContentBlock::Text { text }); + } + } + MessageContent::Parts(parts) => { + // Convert OpenAI content parts to Bedrock ContentBlocks + for part in parts { + match part { + crate::apis::openai::ContentPart::Text { text } => { + if !text.is_empty() { + content_blocks.push(ContentBlock::Text { text }); + } + } + crate::apis::openai::ContentPart::ImageUrl { image_url } => { + // Convert image URL to Bedrock image format + if image_url.url.starts_with("data:") { + if let Some((media_type, data)) = parse_data_url(&image_url.url) { + content_blocks.push(ContentBlock::Image { + image: crate::apis::amazon_bedrock::ImageBlock { + source: crate::apis::amazon_bedrock::ImageSource::Base64 { + media_type, + data, + }, + }, + }); + } else { + 
return Err(TransformError::UnsupportedConversion( + format!("Invalid data URL format: {}", image_url.url) + )); + } + } else { + return Err(TransformError::UnsupportedConversion( + "Only base64 data URLs are supported for images in Bedrock".to_string() + )); + } + } + } + } + } + } + + // Ensure we have at least one content block + if content_blocks.is_empty() { + content_blocks.push(ContentBlock::Text { text: " ".to_string() }); + } + } + Role::Assistant => { + // Handle text content - but only add if non-empty OR if we don't have tool calls + let text_content = message.content.extract_text(); + let has_tool_calls = message.tool_calls.as_ref().map_or(false, |calls| !calls.is_empty()); + + // Add text content if it's non-empty, or if we have no tool calls (to avoid empty content) + if !text_content.is_empty() { + content_blocks.push(ContentBlock::Text { text: text_content }); + } else if !has_tool_calls { + // If we have empty content and no tool calls, add a minimal placeholder + // This prevents the "blank text field" error + content_blocks.push(ContentBlock::Text { text: " ".to_string() }); + } + + // Convert tool calls to ToolUse content blocks + if let Some(tool_calls) = message.tool_calls { + for tool_call in tool_calls { + // Parse the arguments string as JSON + let input: serde_json::Value = serde_json::from_str(&tool_call.function.arguments) + .map_err(|e| TransformError::UnsupportedConversion( + format!("Failed to parse tool arguments as JSON: {}. 
Arguments: {}", e, tool_call.function.arguments) + ))?; + + content_blocks.push(ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: tool_call.id, + name: tool_call.function.name, + input, + }, + }); + } + } + + // Bedrock requires at least one content block + if content_blocks.is_empty() { + content_blocks.push(ContentBlock::Text { text: " ".to_string() }); + } + } + Role::Tool => { + // Tool messages become user messages with ToolResult content blocks + let tool_call_id = message.tool_call_id + .ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?; + + let tool_content = message.content.extract_text(); + + // Create ToolResult content block + let tool_result_content = if tool_content.is_empty() { + // Even for tool results, we need non-empty content + vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text { + text: " ".to_string() + }] + } else { + vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text { + text: tool_content + }] + }; + + content_blocks.push(ContentBlock::ToolResult { + tool_result: crate::apis::amazon_bedrock::ToolResultBlock { + tool_use_id: tool_call_id, + content: tool_result_content, + status: Some(crate::apis::amazon_bedrock::ToolResultStatus::Success), // Default to success + }, + }); + } + Role::System => { + // Already handled above with early return + unreachable!() + } + } + + Ok(BedrockMessage { + role, + content: content_blocks, + }) + } +} + +impl TryFrom for AnthropicMessagesRequest { + type Error = TransformError; + + fn try_from(req: ChatCompletionsRequest) -> Result { + let mut system_prompt = None; + let mut messages = Vec::new(); + + for message in req.messages { + match message.role { + Role::System => { + system_prompt = Some(message.into()); + } + _ => { + let anthropic_message: MessagesMessage = message.try_into()?; + messages.push(anthropic_message); + } + } + } + + // Convert tools and tool choice + let 
anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools)); + let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls); + + Ok(AnthropicMessagesRequest { + model: req.model, + system: system_prompt, + messages, + max_tokens: req.max_completion_tokens + .or(req.max_tokens) + .unwrap_or(DEFAULT_MAX_TOKENS), + container: None, + mcp_servers: None, + service_tier: None, + thinking: None, + temperature: req.temperature, + top_p: req.top_p, + top_k: None, // OpenAI doesn't have top_k + stream: req.stream, + stop_sequences: req.stop, + tools: anthropic_tools, + tool_choice: anthropic_tool_choice, + metadata: None, + }) + } +} + +impl TryFrom for ConverseRequest { + type Error = TransformError; + + fn try_from(req: ChatCompletionsRequest) -> Result { + // Separate system messages from user/assistant messages + let mut system_messages = Vec::new(); + let mut conversation_messages = Vec::new(); + + for message in req.messages { + match message.role { + Role::System => { + let system_text = match message.content { + MessageContent::Text(text) => text, + MessageContent::Parts(parts) => parts.extract_text(), + }; + system_messages.push(SystemContentBlock::Text { text: system_text }); + } + _ => { + let bedrock_message: BedrockMessage = message.try_into()?; + conversation_messages.push(bedrock_message); + } + } + } + + // Convert system messages + let system = if system_messages.is_empty() { + None + } else { + Some(system_messages) + }; + + // Convert conversation messages + let messages = if conversation_messages.is_empty() { + None + } else { + Some(conversation_messages) + }; + + // Build inference configuration + let max_tokens = req.max_completion_tokens.or(req.max_tokens); + let inference_config = if max_tokens.is_some() || req.temperature.is_some() || + req.top_p.is_some() || req.stop.is_some() { + Some(InferenceConfiguration { + max_tokens, + temperature: req.temperature, + top_p: req.top_p, + stop_sequences: req.stop, + 
}) + } else { + None + }; + + // Convert tools and tool choice to ToolConfiguration + let tool_config = if req.tools.is_some() || req.tool_choice.is_some() { + let tools = req.tools.map(|openai_tools| { + openai_tools.into_iter() + .map(|tool| BedrockTool::ToolSpec { + tool_spec: ToolSpecDefinition { + name: tool.function.name, + description: tool.function.description, + input_schema: ToolInputSchema { + json: tool.function.parameters, + }, + }, + }) + .collect() + }); + + let tool_choice = req.tool_choice.map(|choice| { + match choice { + ToolChoice::Type(tool_type) => match tool_type { + ToolChoiceType::Auto => BedrockToolChoice::Auto { auto: AutoChoice {} }, + ToolChoiceType::Required => BedrockToolChoice::Any { any: AnyChoice {} }, + ToolChoiceType::None => BedrockToolChoice::Auto { auto: AutoChoice {} }, // Bedrock doesn't have explicit "none" + }, + ToolChoice::Function { function, .. } => { + BedrockToolChoice::Tool { + tool: ToolChoiceSpec { + name: function.name + } + } + } + } + }).or_else(|| { + // If tools are present but no tool_choice specified, default to "auto" + if tools.is_some() { + Some(BedrockToolChoice::Auto { auto: AutoChoice {} }) + } else { + None + } + }); + + Some(ToolConfiguration { + tools, + tool_choice, + }) + } else { + None + }; + + Ok(ConverseRequest { + model_id: req.model, + messages, + system, + inference_config, + tool_config, + stream: req.stream.unwrap_or(false), + guardrail_config: None, + additional_model_request_fields: None, + additional_model_response_field_paths: None, + performance_config: None, + prompt_variables: None, + request_metadata: None, + metadata: None, + }) + } +} + +/// Convert OpenAI tools to Anthropic format +fn convert_openai_tools(tools: Vec) -> Vec { + tools.into_iter() + .map(|tool| MessagesTool { + name: tool.function.name, + description: tool.function.description, + input_schema: tool.function.parameters, + }) + .collect() +} + + +/// Convert OpenAI tool choice to Anthropic format +fn 
convert_openai_tool_choice( + tool_choice: Option, + parallel_tool_calls: Option +) -> Option { + tool_choice.map(|choice| { + match choice { + ToolChoice::Type(tool_type) => match tool_type { + ToolChoiceType::Auto => MessagesToolChoice { + kind: MessagesToolChoiceType::Auto, + name: None, + disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), + }, + ToolChoiceType::Required => MessagesToolChoice { + kind: MessagesToolChoiceType::Any, + name: None, + disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), + }, + ToolChoiceType::None => MessagesToolChoice { + kind: MessagesToolChoiceType::None, + name: None, + disable_parallel_tool_use: None, + }, + }, + ToolChoice::Function { function, .. } => MessagesToolChoice { + kind: MessagesToolChoiceType::Tool, + name: Some(function.name), + disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), + }, + } + }) +} + +/// Build Anthropic message content from content blocks +fn build_anthropic_content(content_blocks: Vec) -> MessagesMessageContent { + if content_blocks.len() == 1 { + match &content_blocks[0] { + MessagesContentBlock::Text { text, .. 
} => MessagesMessageContent::Single(text.clone()), + _ => MessagesMessageContent::Blocks(content_blocks), + } + } else if content_blocks.is_empty() { + MessagesMessageContent::Single("".to_string()) + } else { + MessagesMessageContent::Blocks(content_blocks) + } +} + + + +/// Parse a data URL into media type and base64 data +/// Supports format: data:image/jpeg;base64, +fn parse_data_url(url: &str) -> Option<(String, String)> { + if !url.starts_with("data:") { + return None; + } + + let without_prefix = &url[5..]; // Remove "data:" prefix + let parts: Vec<&str> = without_prefix.splitn(2, ',').collect(); + + if parts.len() != 2 { + return None; + } + + let header = parts[0]; + let data = parts[1]; + + // Parse header: "image/jpeg;base64" or just "image/jpeg" + let header_parts: Vec<&str> = header.split(';').collect(); + if header_parts.is_empty() { + return None; + } + + let media_type = header_parts[0].to_string(); + + // Check if it's base64 encoded + if header_parts.len() > 1 && header_parts[1] == "base64" { + Some((media_type, data.to_string())) + } else { + // For now, only support base64 encoding + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType, Function, FunctionChoice}; + use crate::apis::amazon_bedrock::{ConverseRequest, SystemContentBlock, ConversationRole, ContentBlock, ToolChoice as BedrockToolChoice}; + use serde_json::json; + + #[test] + fn test_openai_to_bedrock_basic_request() { + let openai_request = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![ + Message { + role: Role::System, + content: MessageContent::Text("You are a helpful assistant.".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }, + Message { + role: Role::User, + content: MessageContent::Text("Hello, how are you?".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + } + ], + temperature: 
Some(0.7), + top_p: Some(0.9), + max_completion_tokens: Some(1000), + stop: Some(vec!["STOP".to_string()]), + stream: Some(false), + tools: None, + tool_choice: None, + ..Default::default() + }; + + let bedrock_request: ConverseRequest = openai_request.try_into().unwrap(); + + assert_eq!(bedrock_request.model_id, "gpt-4"); + assert!(bedrock_request.system.is_some()); + assert_eq!(bedrock_request.system.as_ref().unwrap().len(), 1); + + if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] { + assert_eq!(text, "You are a helpful assistant."); + } else { + panic!("Expected system text block"); + } + + assert!(bedrock_request.messages.is_some()); + let messages = bedrock_request.messages.as_ref().unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, ConversationRole::User); + + if let ContentBlock::Text { text } = &messages[0].content[0] { + assert_eq!(text, "Hello, how are you?"); + } else { + panic!("Expected text content block"); + } + + let inference_config = bedrock_request.inference_config.as_ref().unwrap(); + assert_eq!(inference_config.temperature, Some(0.7)); + assert_eq!(inference_config.top_p, Some(0.9)); + assert_eq!(inference_config.max_tokens, Some(1000)); + assert_eq!(inference_config.stop_sequences, Some(vec!["STOP".to_string()])); + } + + #[test] + fn test_openai_to_bedrock_with_tools() { + let openai_request = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![ + Message { + role: Role::User, + content: MessageContent::Text("What's the weather like?".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + } + ], + temperature: None, + top_p: None, + max_completion_tokens: Some(1000), + stop: None, + stream: None, + tools: Some(vec![ + Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get current weather information".to_string()), + parameters: json!({ + "type": "object", + "properties": { + 
"location": { + "type": "string", + "description": "The city name" + } + }, + "required": ["location"] + }), + strict: None, + }, + } + ]), + tool_choice: Some(ToolChoice::Function { + choice_type: "function".to_string(), + function: FunctionChoice { name: "get_weather".to_string() }, + }), + ..Default::default() + }; + + let bedrock_request: ConverseRequest = openai_request.try_into().unwrap(); + + assert_eq!(bedrock_request.model_id, "gpt-4"); + assert!(bedrock_request.tool_config.is_some()); + + let tool_config = bedrock_request.tool_config.as_ref().unwrap(); + assert!(tool_config.tools.is_some()); + let tools = tool_config.tools.as_ref().unwrap(); + assert_eq!(tools.len(), 1); + + let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0]; + assert_eq!(tool_spec.name, "get_weather"); + assert_eq!(tool_spec.description, Some("Get current weather information".to_string())); + + if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice { + assert_eq!(tool.name, "get_weather"); + } else { + panic!("Expected specific tool choice"); + } + } + + #[test] + fn test_openai_to_bedrock_auto_tool_choice() { + let openai_request = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![ + Message { + role: Role::User, + content: MessageContent::Text("Help me with something".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + } + ], + temperature: None, + top_p: None, + max_completion_tokens: Some(500), + stop: None, + stream: None, + tools: Some(vec![ + Tool { + tool_type: "function".to_string(), + function: Function { + name: "help_tool".to_string(), + description: Some("A helpful tool".to_string()), + parameters: json!({ + "type": "object", + "properties": {} + }), + strict: None, + }, + } + ]), + tool_choice: Some(ToolChoice::Type(ToolChoiceType::Auto)), + ..Default::default() + }; + + let bedrock_request: ConverseRequest = openai_request.try_into().unwrap(); + + 
assert!(bedrock_request.tool_config.is_some()); + let tool_config = bedrock_request.tool_config.as_ref().unwrap(); + assert!(matches!(tool_config.tool_choice, Some(BedrockToolChoice::Auto { .. }))); + } + + #[test] + fn test_openai_to_bedrock_multi_message_conversation() { + let openai_request = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![ + Message { + role: Role::System, + content: MessageContent::Text("Be concise".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }, + Message { + role: Role::User, + content: MessageContent::Text("Hello".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }, + Message { + role: Role::Assistant, + content: MessageContent::Text("Hi there! How can I help you?".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }, + Message { + role: Role::User, + content: MessageContent::Text("What's 2+2?".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + } + ], + temperature: Some(0.5), + top_p: None, + max_completion_tokens: Some(100), + stop: None, + stream: None, + tools: None, + tool_choice: None, + ..Default::default() + }; + + let bedrock_request: ConverseRequest = openai_request.try_into().unwrap(); + + assert!(bedrock_request.messages.is_some()); + let messages = bedrock_request.messages.as_ref().unwrap(); + assert_eq!(messages.len(), 3); // System message is separate + assert_eq!(messages[0].role, ConversationRole::User); + assert_eq!(messages[1].role, ConversationRole::Assistant); + assert_eq!(messages[2].role, ConversationRole::User); + + // Check system prompt + assert!(bedrock_request.system.is_some()); + if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] { + assert_eq!(text, "Be concise"); + } else { + panic!("Expected system text block"); + } + } + + #[test] + fn test_openai_message_to_bedrock_conversion() { + let openai_message = Message { + role: Role::User, + content: 
MessageContent::Text("Test message".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }; + + let bedrock_message: BedrockMessage = openai_message.try_into().unwrap(); + + assert_eq!(bedrock_message.role, ConversationRole::User); + assert_eq!(bedrock_message.content.len(), 1); + + if let ContentBlock::Text { text } = &bedrock_message.content[0] { + assert_eq!(text, "Test message"); + } else { + panic!("Expected text content block"); + } + } +} diff --git a/crates/hermesllm/src/transforms/request/mod.rs b/crates/hermesllm/src/transforms/request/mod.rs new file mode 100644 index 00000000..5fbdf0b1 --- /dev/null +++ b/crates/hermesllm/src/transforms/request/mod.rs @@ -0,0 +1,4 @@ +//! Request transformation modules + +pub mod from_anthropic; +pub mod from_openai; diff --git a/crates/hermesllm/src/transforms/response/mod.rs b/crates/hermesllm/src/transforms/response/mod.rs new file mode 100644 index 00000000..3ce75123 --- /dev/null +++ b/crates/hermesllm/src/transforms/response/mod.rs @@ -0,0 +1,3 @@ +//! 
Response transformation modules +pub mod to_anthropic; +pub mod to_openai; diff --git a/crates/hermesllm/src/transforms/response/to_anthropic.rs b/crates/hermesllm/src/transforms/response/to_anthropic.rs new file mode 100644 index 00000000..f6dc07d5 --- /dev/null +++ b/crates/hermesllm/src/transforms/response/to_anthropic.rs @@ -0,0 +1,866 @@ +use serde_json::Value; +use crate::transforms::lib::*; +use crate::clients::TransformError; +use crate::apis::openai::{ + ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta +}; +use crate::apis::anthropic::{ + MessagesStreamEvent, MessagesStopReason, MessagesMessageDelta, MessagesResponse, + MessagesStreamMessage, MessagesUsage, MessagesContentDelta, MessagesRole, MessagesContentBlock +}; +use crate::apis::amazon_bedrock::{ConverseResponse, ConverseOutput, StopReason}; + +// ============================================================================ +// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience +// ============================================================================ + +impl TryFrom for MessagesResponse { + type Error = TransformError; + + fn try_from(resp: ChatCompletionsResponse) -> Result { + let choice = resp.choices.into_iter().next() + .ok_or_else(|| TransformError::MissingField("choices".to_string()))?; + + let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?; + let stop_reason = choice.finish_reason + .map(|fr| fr.into()) + .unwrap_or(MessagesStopReason::EndTurn); + + let usage = MessagesUsage { + input_tokens: resp.usage.prompt_tokens, + output_tokens: resp.usage.completion_tokens, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }; + + Ok(MessagesResponse { + id: resp.id, + obj_type: "message".to_string(), + role: MessagesRole::Assistant, + content, + model: resp.model, + stop_reason, + stop_sequence: None, + usage, + container: None, + }) + } +} + + +impl TryFrom for MessagesResponse { + type 
Error = TransformError; + + fn try_from(resp: ConverseResponse) -> Result { + // Extract the message from the ConverseOutput + let message = match resp.output { + ConverseOutput::Message { message } => message, + }; + + // Convert Bedrock message content to Anthropic content blocks + let content = convert_bedrock_message_to_anthropic_content(&message)?; + + // Convert Bedrock ConversationRole to Anthropic MessagesRole + let role = match message.role { + crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => MessagesRole::Assistant, + }; + + // Convert Bedrock stop reason to Anthropic stop reason + let stop_reason = match resp.stop_reason { + StopReason::EndTurn => MessagesStopReason::EndTurn, + StopReason::ToolUse => MessagesStopReason::ToolUse, + StopReason::MaxTokens => MessagesStopReason::MaxTokens, + StopReason::StopSequence => MessagesStopReason::EndTurn, + StopReason::GuardrailIntervened => MessagesStopReason::Refusal, + StopReason::ContentFiltered => MessagesStopReason::Refusal, + }; + + // Convert token usage + let usage = MessagesUsage { + input_tokens: resp.usage.input_tokens, + output_tokens: resp.usage.output_tokens, + cache_creation_input_tokens: resp.usage.cache_write_input_tokens, + cache_read_input_tokens: resp.usage.cache_read_input_tokens, + }; + + // Generate a response ID (Bedrock doesn't provide one) + let id = format!("bedrock-{}", std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos()); + + // Extract model ID from trace information if available, otherwise use fallback + let model = resp.trace + .as_ref() + .and_then(|trace| trace.prompt_router.as_ref()) + .map(|router| router.invoked_model_id.clone()) + .unwrap_or_else(|| "bedrock-model".to_string()); + + Ok(MessagesResponse { + id, + obj_type: "message".to_string(), + role, + content, + model, + stop_reason, + stop_sequence: None, // TODO: Could extract from 
additional_model_response_fields if needed + usage, + container: None, + }) + } +} + +impl TryFrom for MessagesStreamEvent { + type Error = TransformError; + + fn try_from(resp: ChatCompletionsStreamResponse) -> Result { + if resp.choices.is_empty() { + return Ok(MessagesStreamEvent::Ping); + } + + let choice = &resp.choices[0]; + + // Handle final chunk with usage + let has_usage = resp.usage.is_some(); + if let Some(usage) = resp.usage { + if let Some(finish_reason) = &choice.finish_reason { + let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); + return Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: usage.into(), + }); + } + } + + // Handle role start + if let Some(Role::Assistant) = choice.delta.role { + return Ok(MessagesStreamEvent::MessageStart { + message: MessagesStreamMessage { + id: resp.id, + obj_type: "message".to_string(), + role: MessagesRole::Assistant, + content: vec![], + model: resp.model, + stop_reason: None, + stop_sequence: None, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }, + }); + } + + // Handle content delta + if let Some(content) = &choice.delta.content { + if !content.is_empty() { + return Ok(MessagesStreamEvent::ContentBlockDelta { + index: 0, + delta: MessagesContentDelta::TextDelta { + text: content.clone(), + }, + }); + } + } + + // Handle tool calls + if let Some(tool_calls) = &choice.delta.tool_calls { + return convert_tool_call_deltas(tool_calls.clone()); + } + + // Handle finish reason - generate MessageDelta only (MessageStop comes later) + if let Some(finish_reason) = &choice.finish_reason { + // If we have usage data, it was already handled above + // If not, we need to generate MessageDelta with default usage + if !has_usage { + let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); + 
return Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }); + } + // If usage was already handled above, we don't need to do anything more here + // MessageStop will be handled when [DONE] is encountered + } + + // Default to ping for unhandled cases + Ok(MessagesStreamEvent::Ping) + } +} + +/// Convert tool call deltas to Anthropic stream events +fn convert_tool_call_deltas(tool_calls: Vec) -> Result { + for tool_call in tool_calls { + if let Some(id) = &tool_call.id { + // Tool call start + if let Some(function) = &tool_call.function { + if let Some(name) = &function.name { + return Ok(MessagesStreamEvent::ContentBlockStart { + index: tool_call.index, + content_block: MessagesContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: Value::Object(serde_json::Map::new()), + cache_control: None, + }, + }); + } + } + } else if let Some(function) = &tool_call.function { + if let Some(arguments) = &function.arguments { + // Tool arguments delta + return Ok(MessagesStreamEvent::ContentBlockDelta { + index: tool_call.index, + delta: MessagesContentDelta::InputJsonDelta { + partial_json: arguments.clone(), + }, + }); + } + } + } + + // Fallback to ping if no valid tool call found + Ok(MessagesStreamEvent::Ping) +} + +/// Convert Bedrock Message to Anthropic content blocks +/// +/// This function handles the conversion between Amazon Bedrock Converse API format +/// and Anthropic's Messages API format. Key differences handled: +/// +/// 1. **Image/Document Sources**: Bedrock supports base64 and S3 locations, while +/// Anthropic supports base64, URLs, and file IDs. Currently only base64 is supported. +/// 2. **Tool Result Status**: Bedrock uses enum status (Success/Error), Anthropic uses +/// boolean is_error field. +/// 3. 
**Document Names**: Bedrock includes optional document names, Anthropic doesn't. +/// 4. **JSON Content**: Bedrock has native JSON content blocks, converted to text for Anthropic. +/// +/// Note on S3/URL handling: Converting S3 locations or URLs would require async operations +/// to download and convert to base64, which is not implemented in this synchronous function. +fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_bedrock::Message) -> Result, TransformError> { + use crate::apis::amazon_bedrock::ContentBlock; + + let mut content_blocks = Vec::new(); + + for content_block in &message.content { + match content_block { + ContentBlock::Text { text } => { + content_blocks.push(MessagesContentBlock::Text { + text: text.clone(), + cache_control: None, + }); + } + ContentBlock::ToolUse { tool_use } => { + content_blocks.push(MessagesContentBlock::ToolUse { + id: tool_use.tool_use_id.clone(), + name: tool_use.name.clone(), + input: tool_use.input.clone(), + cache_control: None, + }); + } + ContentBlock::ToolResult { tool_result } => { + // Convert tool result content blocks + let mut tool_result_blocks = Vec::new(); + for result_content in &tool_result.content { + match result_content { + crate::apis::amazon_bedrock::ToolResultContentBlock::Text { text } => { + tool_result_blocks.push(MessagesContentBlock::Text { + text: text.clone(), + cache_control: None, + }); + } + crate::apis::amazon_bedrock::ToolResultContentBlock::Image { source } => { + // Convert Bedrock ImageSource to Anthropic format + match source { + crate::apis::amazon_bedrock::ImageSource::Base64 { media_type, data } => { + tool_result_blocks.push(MessagesContentBlock::Image { + source: crate::apis::anthropic::MessagesImageSource::Base64 { + media_type: media_type.clone(), + data: data.clone(), + }, + }); + } + // Note: S3Location is not yet implemented in the current Bedrock API definition + // but would need async handling when added + } + } + 
crate::apis::amazon_bedrock::ToolResultContentBlock::Json { json } => { + // Convert JSON content to text representation + tool_result_blocks.push(MessagesContentBlock::Text { + text: serde_json::to_string(&json).unwrap_or_default(), + cache_control: None, + }); + } + } + } + + use crate::apis::anthropic::ToolResultContent; + content_blocks.push(MessagesContentBlock::ToolResult { + tool_use_id: tool_result.tool_use_id.clone(), + is_error: tool_result.status.as_ref().map(|s| matches!(s, crate::apis::amazon_bedrock::ToolResultStatus::Error)), + content: ToolResultContent::Blocks(tool_result_blocks), + cache_control: None, + }); + } + ContentBlock::Image { image } => { + // Convert Bedrock ImageSource to Anthropic format + match &image.source { + crate::apis::amazon_bedrock::ImageSource::Base64 { media_type, data } => { + content_blocks.push(MessagesContentBlock::Image { + source: crate::apis::anthropic::MessagesImageSource::Base64 { + media_type: media_type.clone(), + data: data.clone(), + }, + }); + } + // Note: S3Location would require async handling if implemented + } + } + ContentBlock::Document { document } => { + // Convert Bedrock DocumentSource to Anthropic format + // Note: Bedrock's 'name' field is lost in conversion as Anthropic doesn't support it + match &document.source { + crate::apis::amazon_bedrock::DocumentSource::Base64 { media_type, data } => { + content_blocks.push(MessagesContentBlock::Document { + source: crate::apis::anthropic::MessagesDocumentSource::Base64 { + media_type: media_type.clone(), + data: data.clone(), + }, + }); + } + // Note: S3Location would require async handling if implemented + } + } + ContentBlock::GuardContent { guard_content } => { + // Convert guard content to text block + if let Some(guard_text) = &guard_content.text { + content_blocks.push(MessagesContentBlock::Text { + text: guard_text.text.clone(), + cache_control: None, + }); + } + } + } + } + + Ok(content_blocks) +} + +#[cfg(test)] +mod tests { + use super::*; + use 
crate::apis::amazon_bedrock::{ + ConverseResponse, ConverseOutput, Message as BedrockMessage, ConversationRole, + ContentBlock, StopReason, BedrockTokenUsage, ToolResultContentBlock, ToolResultStatus, + ConverseTrace, PromptRouterTrace + }; + use crate::apis::anthropic::{MessagesResponse, MessagesContentBlock, MessagesStopReason, MessagesRole, ToolResultContent}; + use serde_json::json; + + #[test] + fn test_bedrock_to_anthropic_basic_response() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Hello! How can I help you today?".to_string(), + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 25, + total_tokens: 35, + cache_write_input_tokens: None, + cache_read_input_tokens: None, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.obj_type, "message"); + assert_eq!(anthropic_response.role, MessagesRole::Assistant); + assert_eq!(anthropic_response.model, "bedrock-model"); + assert_eq!(anthropic_response.stop_reason, MessagesStopReason::EndTurn); + assert!(anthropic_response.id.starts_with("bedrock-")); + + // Check content + assert_eq!(anthropic_response.content.len(), 1); + if let MessagesContentBlock::Text { text, .. } = &anthropic_response.content[0] { + assert_eq!(text, "Hello! 
How can I help you today?"); + } else { + panic!("Expected text content block"); + } + + // Check usage + assert_eq!(anthropic_response.usage.input_tokens, 10); + assert_eq!(anthropic_response.usage.output_tokens, 25); + assert_eq!(anthropic_response.usage.cache_creation_input_tokens, None); + assert_eq!(anthropic_response.usage.cache_read_input_tokens, None); + } + + #[test] + fn test_bedrock_to_anthropic_with_tool_use() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I'll help you check the weather.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_12345".to_string(), + name: "get_weather".to_string(), + input: json!({ + "location": "San Francisco" + }), + }, + } + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 15, + output_tokens: 30, + total_tokens: 45, + cache_write_input_tokens: None, + cache_read_input_tokens: None, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.stop_reason, MessagesStopReason::ToolUse); + assert_eq!(anthropic_response.content.len(), 2); + + // Check text content + if let MessagesContentBlock::Text { text, .. } = &anthropic_response.content[0] { + assert_eq!(text, "I'll help you check the weather."); + } else { + panic!("Expected text content block"); + } + + // Check tool use content + if let MessagesContentBlock::ToolUse { id, name, input, .. 
} = &anthropic_response.content[1] { + assert_eq!(id, "tool_12345"); + assert_eq!(name, "get_weather"); + assert_eq!(input["location"], "San Francisco"); + } else { + panic!("Expected tool use content block"); + } + } + + #[test] + fn test_bedrock_to_anthropic_stop_reason_conversions() { + let test_cases = vec![ + (StopReason::EndTurn, MessagesStopReason::EndTurn), + (StopReason::ToolUse, MessagesStopReason::ToolUse), + (StopReason::MaxTokens, MessagesStopReason::MaxTokens), + (StopReason::StopSequence, MessagesStopReason::EndTurn), + (StopReason::GuardrailIntervened, MessagesStopReason::Refusal), + (StopReason::ContentFiltered, MessagesStopReason::Refusal), + ]; + + for (bedrock_stop_reason, expected_anthropic_stop_reason) in test_cases { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Test response".to_string(), + } + ], + }, + }, + stop_reason: bedrock_stop_reason, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + assert_eq!(anthropic_response.stop_reason, expected_anthropic_stop_reason); + } + } + + #[test] + fn test_bedrock_to_anthropic_with_cache_tokens() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Cached response".to_string(), + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 100, + output_tokens: 50, + total_tokens: 150, + cache_write_input_tokens: Some(20), + cache_read_input_tokens: Some(10), + ..Default::default() + }, + metrics: None, + trace: None, + 
additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.usage.input_tokens, 100); + assert_eq!(anthropic_response.usage.output_tokens, 50); + assert_eq!(anthropic_response.usage.cache_creation_input_tokens, Some(20)); + assert_eq!(anthropic_response.usage.cache_read_input_tokens, Some(10)); + } + + #[test] + fn test_bedrock_to_anthropic_with_tool_result() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Here's the weather information:".to_string(), + }, + ContentBlock::ToolResult { + tool_result: crate::apis::amazon_bedrock::ToolResultBlock { + tool_use_id: "tool_67890".to_string(), + content: vec![ + ToolResultContentBlock::Text { + text: "Temperature: 72°F, Sunny".to_string(), + } + ], + status: Some(ToolResultStatus::Success), + }, + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 20, + output_tokens: 35, + total_tokens: 55, + cache_write_input_tokens: None, + cache_read_input_tokens: None, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.content.len(), 2); + + // Check text content + if let MessagesContentBlock::Text { text, .. } = &anthropic_response.content[0] { + assert_eq!(text, "Here's the weather information:"); + } else { + panic!("Expected text content block"); + } + + // Check tool result content + if let MessagesContentBlock::ToolResult { tool_use_id, content, .. 
} = &anthropic_response.content[1] { + assert_eq!(tool_use_id, "tool_67890"); + if let ToolResultContent::Blocks(blocks) = content { + assert_eq!(blocks.len(), 1); + if let MessagesContentBlock::Text { text, .. } = &blocks[0] { + assert_eq!(text, "Temperature: 72°F, Sunny"); + } else { + panic!("Expected text content in tool result"); + } + } else { + panic!("Expected blocks in tool result content"); + } + } else { + panic!("Expected tool result content block"); + } + } + + #[test] + fn test_bedrock_to_anthropic_mixed_content() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I can help with multiple tasks.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_1".to_string(), + name: "search".to_string(), + input: json!({"query": "weather"}), + }, + }, + ContentBlock::Text { + text: "Let me also check another source.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_2".to_string(), + name: "lookup".to_string(), + input: json!({"id": "12345"}), + }, + } + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 25, + output_tokens: 40, + total_tokens: 65, + cache_write_input_tokens: None, + cache_read_input_tokens: None, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.content.len(), 4); + assert_eq!(anthropic_response.stop_reason, MessagesStopReason::ToolUse); + + // Verify the sequence: text -> tool_use -> text -> tool_use + if let MessagesContentBlock::Text { text, .. 
} = &anthropic_response.content[0] { + assert_eq!(text, "I can help with multiple tasks."); + } else { + panic!("Expected first content to be text"); + } + + if let MessagesContentBlock::ToolUse { id, name, .. } = &anthropic_response.content[1] { + assert_eq!(id, "tool_1"); + assert_eq!(name, "search"); + } else { + panic!("Expected second content to be tool use"); + } + + if let MessagesContentBlock::Text { text, .. } = &anthropic_response.content[2] { + assert_eq!(text, "Let me also check another source."); + } else { + panic!("Expected third content to be text"); + } + + if let MessagesContentBlock::ToolUse { id, name, .. } = &anthropic_response.content[3] { + assert_eq!(id, "tool_2"); + assert_eq!(name, "lookup"); + } else { + panic!("Expected fourth content to be tool use"); + } + } + + #[test] + fn test_convert_bedrock_message_to_anthropic_content() { + let bedrock_message = BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Hello world!".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "test_tool".to_string(), + name: "test_function".to_string(), + input: json!({"param": "value"}), + }, + } + ], + }; + + let content_blocks = convert_bedrock_message_to_anthropic_content(&bedrock_message).unwrap(); + + assert_eq!(content_blocks.len(), 2); + + if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] { + assert_eq!(text, "Hello world!"); + } else { + panic!("Expected text content block"); + } + + if let MessagesContentBlock::ToolUse { id, name, input, .. 
} = &content_blocks[1] { + assert_eq!(id, "test_tool"); + assert_eq!(name, "test_function"); + assert_eq!(input["param"], "value"); + } else { + panic!("Expected tool use content block"); + } + } + + #[test] + fn test_bedrock_to_anthropic_role_conversion() { + // Test Assistant role + let assistant_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I am an assistant".to_string(), + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = assistant_response.try_into().unwrap(); + assert_eq!(anthropic_response.role, MessagesRole::Assistant); + + // Test User role + let user_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::User, + content: vec![ + ContentBlock::Text { + text: "I am a user".to_string(), + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = user_response.try_into().unwrap(); + assert_eq!(anthropic_response.role, MessagesRole::User); + } + + #[test] + fn test_bedrock_to_anthropic_model_extraction() { + // Test model extraction from trace information + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: 
BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: Some(ConverseTrace { + guardrail: None, + prompt_router: Some(PromptRouterTrace { + invoked_model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(), + }), + }), + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + // Should extract model ID from trace + assert_eq!(anthropic_response.model, "anthropic.claude-3-sonnet-20240229-v1:0"); + + // Test fallback when no trace information is available + let bedrock_response_no_trace = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response_fallback: MessagesResponse = bedrock_response_no_trace.try_into().unwrap(); + + // Should use fallback model name + assert_eq!(anthropic_response_fallback.model, "bedrock-model"); + } +} diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs new file mode 100644 index 00000000..53820708 --- /dev/null +++ b/crates/hermesllm/src/transforms/response/to_openai.rs @@ -0,0 +1,963 @@ +use crate::apis::openai::{ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason, ResponseMessage, Role, ToolCallDelta, FunctionCallDelta, Usage, StreamChoice, MessageDelta, MessageContent}; +use crate::apis::anthropic::{MessagesResponse, MessagesStreamEvent, MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesUsage}; +use 
crate::apis::amazon_bedrock::{ConverseResponse, ConverseOutput, StopReason}; +use crate::clients::TransformError; +use crate::transforms::lib::*; + + +// ============================================================================ +// MAIN RESPONSE TRANSFORMATIONS +// ============================================================================ + +// Usage Conversions +impl Into for MessagesUsage { + fn into(self) -> Usage { + Usage { + prompt_tokens: self.input_tokens, + completion_tokens: self.output_tokens, + total_tokens: self.input_tokens + self.output_tokens, + prompt_tokens_details: None, + completion_tokens_details: None, + } + } +} + +impl TryFrom for ChatCompletionsResponse { + type Error = TransformError; + + fn try_from(resp: MessagesResponse) -> Result { + let content = convert_anthropic_content_to_openai(&resp.content)?; + let finish_reason: FinishReason = resp.stop_reason.into(); + let tool_calls = resp.content.extract_tool_calls()?; + + // Convert MessageContent to String for response + let content_string = match content { + MessageContent::Text(text) => Some(text), + MessageContent::Parts(parts) => { + let text = parts.extract_text(); + if text.is_empty() { None } else { Some(text) } + } + }; + + let message = ResponseMessage { + role: Role::Assistant, + content: content_string, + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls, + }; + + let choice = Choice { + index: 0, + message, + finish_reason: Some(finish_reason), + logprobs: None, + }; + + let usage = Usage { + prompt_tokens: resp.usage.input_tokens, + completion_tokens: resp.usage.output_tokens, + total_tokens: resp.usage.input_tokens + resp.usage.output_tokens, + prompt_tokens_details: None, + completion_tokens_details: None, + }; + + Ok(ChatCompletionsResponse { + id: resp.id, + object: Some("chat.completion".to_string()), + created: current_timestamp(), + model: resp.model, + choices: vec![choice], + usage, + system_fingerprint: None, + service_tier: 
None, + }) + } +} + +impl TryFrom for ChatCompletionsResponse { + type Error = TransformError; + + fn try_from(resp: ConverseResponse) -> Result { + // Extract the message from the ConverseOutput + let message = match resp.output { + ConverseOutput::Message { message } => message, + }; + + // Convert Bedrock ConversationRole to OpenAI Role + let role = match message.role { + crate::apis::amazon_bedrock::ConversationRole::User => Role::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant, + }; + + // Convert Bedrock message content to OpenAI format + let (content, tool_calls) = convert_bedrock_message_to_openai(&message)?; + + // Convert Bedrock stop reason to OpenAI finish reason + let finish_reason = match resp.stop_reason { + StopReason::EndTurn => FinishReason::Stop, + StopReason::ToolUse => FinishReason::ToolCalls, + StopReason::MaxTokens => FinishReason::Length, + StopReason::StopSequence => FinishReason::Stop, + StopReason::GuardrailIntervened => FinishReason::ContentFilter, + StopReason::ContentFiltered => FinishReason::ContentFilter, + }; + + + // Create response message + let response_message = ResponseMessage { + role, + content, + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls, + }; + + // Create choice + let choice = Choice { + index: 0, + message: response_message, + finish_reason: Some(finish_reason), + logprobs: None, + }; + + // Convert token usage + let usage = Usage { + prompt_tokens: resp.usage.input_tokens, + completion_tokens: resp.usage.output_tokens, + total_tokens: resp.usage.total_tokens, + prompt_tokens_details: None, + completion_tokens_details: None, + }; + + // Generate a response ID (using timestamp since Bedrock doesn't provide one) + let id = format!("bedrock-{}", std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos()); + + // Extract model ID from trace information if available, otherwise use fallback + let 
model = resp.trace + .as_ref() + .and_then(|trace| trace.prompt_router.as_ref()) + .map(|router| router.invoked_model_id.clone()) + .unwrap_or_else(|| "bedrock-model".to_string()); + + Ok(ChatCompletionsResponse { + id, + object: Some("chat.completion".to_string()), + created: current_timestamp(), + model, + choices: vec![choice], + usage, + system_fingerprint: None, + service_tier: None, + }) + } +} + + +// ============================================================================ +// STREAMING TRANSFORMATIONS +// ============================================================================ + +impl TryFrom for ChatCompletionsStreamResponse { + type Error = TransformError; + + fn try_from(event: MessagesStreamEvent) -> Result { + match event { + MessagesStreamEvent::MessageStart { message } => { + Ok(create_openai_chunk( + &message.id, + &message.model, + MessageDelta { + role: Some(Role::Assistant), + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )) + } + + MessagesStreamEvent::ContentBlockStart { content_block, .. } => { + convert_content_block_start(content_block) + } + + MessagesStreamEvent::ContentBlockDelta { delta, .. } => { + convert_content_delta(delta) + } + + MessagesStreamEvent::ContentBlockStop { .. 
} => { + Ok(create_empty_openai_chunk()) + } + + MessagesStreamEvent::MessageDelta { delta, usage } => { + let finish_reason: Option = Some(delta.stop_reason.into()); + let openai_usage: Option = Some(usage.into()); + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + finish_reason, + openai_usage, + )) + } + + MessagesStreamEvent::MessageStop => { + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + Some(FinishReason::Stop), + None, + )) + } + + MessagesStreamEvent::Ping => { + Ok(ChatCompletionsStreamResponse { + id: "stream".to_string(), + object: Some("chat.completion.chunk".to_string()), + created: current_timestamp(), + model: "unknown".to_string(), + choices: vec![], + usage: None, + system_fingerprint: None, + service_tier: None, + }) + } + } + } +} + + +/// Convert content block start to OpenAI chunk +fn convert_content_block_start(content_block: MessagesContentBlock) -> Result { + match content_block { + MessagesContentBlock::Text { .. } => { + // No immediate output for text block start + Ok(create_empty_openai_chunk()) + } + MessagesContentBlock::ToolUse { id, name, .. } | + MessagesContentBlock::ServerToolUse { id, name, .. } | + MessagesContentBlock::McpToolUse { id, name, .. 
} => { + // Tool use start → OpenAI chunk with tool_calls + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: Some(id), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(name), + arguments: Some("".to_string()), + }), + }]), + }, + None, + None, + )) + } + _ => Err(TransformError::UnsupportedContent("Unsupported content block type in stream start".to_string())), + } +} + +/// Convert content delta to OpenAI chunk +fn convert_content_delta(delta: MessagesContentDelta) -> Result { + match delta { + MessagesContentDelta::TextDelta { text } => { + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(text), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )) + } + MessagesContentDelta::ThinkingDelta { thinking } => { + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(format!("thinking: {}", thinking)), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )) + } + MessagesContentDelta::InputJsonDelta { partial_json } => { + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: None, + call_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(partial_json), + }), + }]), + }, + None, + None, + )) + } + } +} + +/// Helper to create OpenAI streaming chunk +fn create_openai_chunk( + id: &str, + model: &str, + delta: MessageDelta, + finish_reason: Option, + usage: Option +) -> ChatCompletionsStreamResponse { + ChatCompletionsStreamResponse { + id: id.to_string(), + object: Some("chat.completion.chunk".to_string()), + created: current_timestamp(), + model: model.to_string(), + 
choices: vec![StreamChoice { + index: 0, + delta, + finish_reason, + logprobs: None, + }], + usage, + system_fingerprint: None, + service_tier: None, + } +} + +/// Helper to create empty OpenAI streaming chunk +fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { + create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + ) +} + +/// Convert Anthropic content blocks to OpenAI message content +fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Result { + let mut text_parts = Vec::new(); + + for block in content { + match block { + MessagesContentBlock::Text { text, .. } => { + text_parts.push(text.clone()); + } + MessagesContentBlock::Thinking { thinking, .. } => { + text_parts.push(format!("thinking: {}", thinking)); + } + _ => { + // Skip other content types for basic text conversion + continue; + } + } + } + + Ok(MessageContent::Text(text_parts.join("\n"))) +} + +// Stop Reason Conversions +impl Into for MessagesStopReason { + fn into(self) -> FinishReason { + match self { + MessagesStopReason::EndTurn => FinishReason::Stop, + MessagesStopReason::MaxTokens => FinishReason::Length, + MessagesStopReason::StopSequence => FinishReason::Stop, + MessagesStopReason::ToolUse => FinishReason::ToolCalls, + MessagesStopReason::PauseTurn => FinishReason::Stop, + MessagesStopReason::Refusal => FinishReason::ContentFilter, + } + } +} + +/// Convert Bedrock Message to OpenAI content and tool calls +/// This function extracts text content and tool calls from a Bedrock message +fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Message) -> Result<(Option, Option>), TransformError> { + use crate::apis::amazon_bedrock::ContentBlock; + use crate::apis::openai::{ToolCall, FunctionCall}; + + let mut text_content = String::new(); + let mut tool_calls = Vec::new(); + + for content_block in &message.content { 
+ match content_block { + ContentBlock::Text { text } => { + text_content.push_str(text); + } + ContentBlock::ToolUse { tool_use } => { + tool_calls.push(ToolCall { + id: tool_use.tool_use_id.clone(), + call_type: "function".to_string(), + function: FunctionCall { + name: tool_use.name.clone(), + arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(), + }, + }); + } + _ => continue, + } + } + + let content = if text_content.is_empty() { None } else { Some(text_content) }; + let tool_calls = if tool_calls.is_empty() { None } else { Some(tool_calls) }; + + Ok((content, tool_calls)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::apis::amazon_bedrock::{ + ConverseResponse, ConverseOutput, Message as BedrockMessage, ConversationRole, + ContentBlock, StopReason, BedrockTokenUsage, ConverseTrace, PromptRouterTrace + }; + use crate::apis::openai::{ChatCompletionsResponse, FinishReason, Role}; + use serde_json::json; + + #[test] + fn test_bedrock_to_openai_basic_response() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Hello! 
How can I help you today?".to_string(), + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 25, + total_tokens: 35, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(openai_response.object, Some("chat.completion".to_string())); + assert_eq!(openai_response.model, "bedrock-model"); + assert!(openai_response.id.starts_with("bedrock-")); + + // Check choices + assert_eq!(openai_response.choices.len(), 1); + let choice = &openai_response.choices[0]; + assert_eq!(choice.index, 0); + assert_eq!(choice.message.role, Role::Assistant); + assert_eq!(choice.message.content, Some("Hello! How can I help you today?".to_string())); + assert_eq!(choice.finish_reason, Some(FinishReason::Stop)); + assert!(choice.message.tool_calls.is_none()); + + // Check usage + assert_eq!(openai_response.usage.prompt_tokens, 10); + assert_eq!(openai_response.usage.completion_tokens, 25); + assert_eq!(openai_response.usage.total_tokens, 35); + } + + #[test] + fn test_bedrock_to_openai_with_tool_use() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I'll help you check the weather.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_12345".to_string(), + name: "get_weather".to_string(), + input: json!({ + "location": "San Francisco" + }), + }, + } + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 15, + output_tokens: 30, + total_tokens: 45, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let 
openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::ToolCalls)); + assert_eq!(openai_response.choices[0].message.content, Some("I'll help you check the weather.".to_string())); + + // Check tool calls + let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + + let tool_call = &tool_calls[0]; + assert_eq!(tool_call.id, "tool_12345"); + assert_eq!(tool_call.call_type, "function"); + assert_eq!(tool_call.function.name, "get_weather"); + + let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments).unwrap(); + assert_eq!(args["location"], "San Francisco"); + } + + #[test] + fn test_bedrock_to_openai_stop_reason_conversions() { + let test_cases = vec![ + (StopReason::EndTurn, FinishReason::Stop), + (StopReason::ToolUse, FinishReason::ToolCalls), + (StopReason::MaxTokens, FinishReason::Length), + (StopReason::StopSequence, FinishReason::Stop), + (StopReason::GuardrailIntervened, FinishReason::ContentFilter), + (StopReason::ContentFiltered, FinishReason::ContentFilter), + ]; + + for (bedrock_stop_reason, expected_openai_finish_reason) in test_cases { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Test response".to_string(), + } + ], + }, + }, + stop_reason: bedrock_stop_reason, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + assert_eq!(openai_response.choices[0].finish_reason, Some(expected_openai_finish_reason)); + } + } + + #[test] + fn 
test_bedrock_to_openai_multiple_tool_calls() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I'll help with multiple tasks.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_1".to_string(), + name: "search".to_string(), + input: json!({"query": "weather"}), + }, + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_2".to_string(), + name: "lookup".to_string(), + input: json!({"id": "12345"}), + }, + } + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 25, + output_tokens: 40, + total_tokens: 65, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::ToolCalls)); + assert_eq!(openai_response.choices[0].message.content, Some("I'll help with multiple tasks.".to_string())); + + // Check multiple tool calls + let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 2); + + // First tool call + assert_eq!(tool_calls[0].id, "tool_1"); + assert_eq!(tool_calls[0].function.name, "search"); + let args1: serde_json::Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); + assert_eq!(args1["query"], "weather"); + + // Second tool call + assert_eq!(tool_calls[1].id, "tool_2"); + assert_eq!(tool_calls[1].function.name, "lookup"); + let args2: serde_json::Value = serde_json::from_str(&tool_calls[1].function.arguments).unwrap(); + assert_eq!(args2["id"], "12345"); + } + + #[test] + fn test_bedrock_to_openai_mixed_content() { + let bedrock_response = 
ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "First part. ".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_mid".to_string(), + name: "calculate".to_string(), + input: json!({"expr": "2+2"}), + }, + }, + ContentBlock::Text { + text: "Second part.".to_string(), + } + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 20, + output_tokens: 35, + total_tokens: 55, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + // Content should be combined text parts (no separator added) + assert_eq!(openai_response.choices[0].message.content, Some("First part. Second part.".to_string())); + + // Should have one tool call + let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "tool_mid"); + assert_eq!(tool_calls[0].function.name, "calculate"); + } + + #[test] + fn test_bedrock_to_openai_empty_content() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_only".to_string(), + name: "action".to_string(), + input: json!({}), + }, + } + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = 
bedrock_response.try_into().unwrap(); + + // Content should be None when there's no text + assert_eq!(openai_response.choices[0].message.content, None); + + // Should have tool call + let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "tool_only"); + } + + #[test] + fn test_convert_bedrock_message_to_openai() { + let bedrock_message = BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Hello world!".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "test_tool".to_string(), + name: "test_function".to_string(), + input: json!({"param": "value"}), + }, + } + ], + }; + + let (content, tool_calls) = convert_bedrock_message_to_openai(&bedrock_message).unwrap(); + + assert_eq!(content, Some("Hello world!".to_string())); + + let tool_calls = tool_calls.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "test_tool"); + assert_eq!(tool_calls[0].function.name, "test_function"); + + let args: serde_json::Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); + assert_eq!(args["param"], "value"); + } + + #[test] + fn test_bedrock_to_openai_role_conversion() { + // Test Assistant role + let assistant_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I am an assistant".to_string(), + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = assistant_response.try_into().unwrap(); + assert_eq!(openai_response.choices[0].message.role, 
Role::Assistant); + + // Test User role + let user_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::User, + content: vec![ + ContentBlock::Text { + text: "I am a user".to_string(), + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = user_response.try_into().unwrap(); + assert_eq!(openai_response.choices[0].message.role, Role::User); + } + + #[test] + fn test_bedrock_to_openai_model_extraction() { + // Test model extraction from trace information + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: Some(ConverseTrace { + guardrail: None, + prompt_router: Some(PromptRouterTrace { + invoked_model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(), + }), + }), + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + // Should extract model ID from trace + assert_eq!(openai_response.model, "anthropic.claude-3-sonnet-20240229-v1:0"); + + // Test fallback when no trace information is available + let bedrock_response_no_trace = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: 
StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response_fallback: ChatCompletionsResponse = bedrock_response_no_trace.try_into().unwrap(); + + // Should use fallback model name + assert_eq!(openai_response_fallback.model, "bedrock-model"); + } + + #[test] + fn test_bedrock_to_openai_with_multimedia_content() { + use crate::apis::amazon_bedrock::{ImageSource}; + + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Here's the analysis:".to_string(), + }, + ContentBlock::Image { + image: crate::apis::amazon_bedrock::ImageBlock { + source: ImageSource::Base64 { + media_type: "image/jpeg".to_string(), + data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(), + }, + }, + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 50, + output_tokens: 75, + total_tokens: 125, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::Stop)); + + let content = openai_response.choices[0].message.content.as_ref().unwrap(); + + // Check that text content is preserved (image blocks are currently ignored) + assert!(content.contains("Here's the analysis:")); + // Note: Image blocks are not converted to text in the current implementation + } +} diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index f11134cb..8ade853a 100644 --- 
a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,3 +1,4 @@ +use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -12,8 +13,8 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::metrics::Metrics; use common::configuration::{LlmProvider, LlmProviderType, Overrides}; use common::consts::{ - ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, RATELIMIT_SELECTOR_HEADER_KEY, - REQUEST_ID_HEADER, TRACE_PARENT_HEADER, + ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, + RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, }; use common::errors::ServerError; use common::llm_providers::LlmProviders; @@ -33,7 +34,7 @@ pub struct StreamContext { /// The API that is requested by the client (before compatibility mapping) client_api: Option, /// The API that should be used for the upstream provider (after compatibility mapping) - resolved_api: Option, + resolved_api: Option, llm_providers: Rc, llm_provider: Option>, request_id: Option, @@ -108,6 +109,7 @@ impl StreamContext { .model .as_ref() .unwrap_or(&"".to_string()), + self.streaming_response, ); if target_endpoint != request_path { self.set_http_request_header(":path", Some(&target_endpoint)); @@ -148,14 +150,19 @@ impl StreamContext { // Set API-specific headers based on the resolved upstream API match self.resolved_api.as_ref() { - Some(SupportedAPIs::AnthropicMessagesAPI(_)) => { + Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { // Anthropic API requires x-api-key and anthropic-version headers // Remove any existing Authorization header since Anthropic doesn't use it self.remove_http_request_header("Authorization"); self.set_http_request_header("x-api-key", Some(llm_provider_api_key_value)); self.set_http_request_header("anthropic-version", Some("2023-06-01")); } - 
Some(SupportedAPIs::OpenAIChatCompletions(_)) | None => { + Some( + SupportedUpstreamAPIs::OpenAIChatCompletions(_) + | SupportedUpstreamAPIs::AmazonBedrockConverse(_) + | SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + ) + | None => { // OpenAI and default: use Authorization Bearer token // Remove any existing x-api-key header since OpenAI doesn't use it self.remove_http_request_header("x-api-key"); @@ -410,7 +417,8 @@ impl StreamContext { match self.client_api.as_ref() { Some(client_api) => { let client_api = client_api.clone(); // Clone to avoid borrowing issues - let upstream_api = provider_id.compatible_api_for_client(&client_api); + let upstream_api = + provider_id.compatible_api_for_client(&client_api, self.streaming_response); // Parse body into SSE iterator using TryFrom let sse_iter: SseStreamIter> = @@ -578,6 +586,11 @@ impl HttpContext for StreamContext { return Action::Continue; } + self.streaming_response = self + .get_http_request_header(ARCH_IS_STREAMING_HEADER) + .map(|val| val == "true") + .unwrap_or(false); + let use_agent_orchestrator = match self.overrides.as_ref() { Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(), None => false, @@ -612,7 +625,17 @@ impl HttpContext for StreamContext { (self.client_api.as_ref(), self.llm_provider.as_ref()) { let provider_id = provider.to_provider_id(); - self.resolved_api = Some(provider_id.compatible_api_for_client(api)); + self.resolved_api = + Some(provider_id.compatible_api_for_client(api, self.streaming_response)); + + debug!( + "[ARCHGW_REQ_ID:{}] ROUTING_INFO: provider='{}' client_api={:?} resolved_api={:?} request_path='{}'", + self.request_identifier(), + provider.to_provider_id(), + api, + self.resolved_api, + request_path + ); } else { self.resolved_api = None; } @@ -697,7 +720,7 @@ impl HttpContext for StreamContext { //We need to deserialize the request body based on the resolved API let mut deserialized_client_request: ProviderRequestType = match 
self.client_api.as_ref() { Some(the_client_api) => { - debug!( + info!( "[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_RECEIVED: api={:?} body_size={}", self.request_identifier(), the_client_api, @@ -795,7 +818,10 @@ impl HttpContext for StreamContext { ); // Use provider interface for streaming detection and setup - self.streaming_response = deserialized_client_request.is_streaming(); + // If streaming_response is not already set from headers, get it from the parsed request + if !self.streaming_response { + self.streaming_response = deserialized_client_request.is_streaming(); + } // Use provider interface for text extraction (after potential mutation) let input_tokens_str = deserialized_client_request.extract_messages_text(); diff --git a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml index 794ed117..0a626be9 100644 --- a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml +++ b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml @@ -35,6 +35,10 @@ llm_providers: access_key: $AZURE_API_KEY base_url: https://katanemo.openai.azure.com + - model: amazon_bedrock/us.amazon.nova-premier-v1:0 + access_key: $AWS_BEARER_TOKEN_BEDROCK + base_url: https://bedrock-runtime.us-west-2.amazonaws.com + # Ollama Models - model: ollama/llama3.1 base_url: http://host.docker.internal:11434 @@ -71,3 +75,6 @@ model_aliases: creative-model: target: claude-sonnet-4-20250514 + + coding-model: + target: us.amazon.nova-premier-v1:0 diff --git a/tests/e2e/test_model_alias_routing.py b/tests/e2e/test_model_alias_routing.py index d5a289a6..6539deb6 100644 --- a/tests/e2e/test_model_alias_routing.py +++ b/tests/e2e/test_model_alias_routing.py @@ -403,3 +403,99 @@ def test_anthropic_thinking_mode_streaming(): final_block_types = [blk.type for blk in final.content] assert "text" in final_block_types assert "thinking" in final_block_types + + +def 
def test_openai_client_with_coding_model_alias_and_tools():
    """Test OpenAI client using 'coding-model' alias (maps to Bedrock) with coding question and tools.

    Sends a chat-completions request through the gateway with a single
    ``run_python_code`` function tool and ``tool_choice="auto"``, then checks
    that the reply carries non-empty text and/or at least one tool call.
    """
    logger.info("Testing OpenAI client with 'coding-model' alias -> Bedrock with tools")

    # The OpenAI SDK wants the API root, so strip the endpoint path and re-add "/v1".
    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
    client = openai.OpenAI(
        api_key="test-key",
        base_url=f"{base_url}/v1",
    )

    completion = client.chat.completions.create(
        model="coding-model",  # This should resolve to us.amazon.nova-premier-v1:0
        max_tokens=1000,
        messages=[
            {
                "role": "user",
                "content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?",
            }
        ],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "run_python_code",
                    "description": "Execute Python code and return the result",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "code": {
                                "type": "string",
                                "description": "Python code to execute",
                            }
                        },
                        "required": ["code"],
                    },
                },
            }
        ],
        tool_choice="auto",
    )

    message = completion.choices[0].message
    response_content = message.content
    # tool_calls is None when the model answered with text only.
    tool_calls = message.tool_calls or []

    logger.info(f"Response from coding-model alias via OpenAI: {response_content}")
    logger.info(f"Tool calls: {len(tool_calls)}")

    # Should get either text response or tool calls for coding assistance.
    # Truthiness (not "is not None") so an empty-string reply without tool
    # calls fails, matching the Anthropic variant of this test.
    assert response_content or tool_calls


def test_anthropic_client_with_coding_model_alias_and_tools():
    """Test Anthropic client using 'coding-model' alias (maps to Bedrock) with coding question and tools.

    Sends a Messages request through the gateway with a single
    ``run_python_code`` tool and ``tool_choice={"type": "auto"}``, then checks
    that the reply carries non-empty text and/or at least one ``tool_use`` block.
    """
    logger.info(
        "Testing Anthropic client with 'coding-model' alias -> Bedrock with tools"
    )

    # The Anthropic client is given the bare gateway root (no "/v1" suffix).
    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
    client = anthropic.Anthropic(api_key="test-key", base_url=base_url)

    message = client.messages.create(
        model="coding-model",  # This should resolve to us.amazon.nova-premier-v1:0
        max_tokens=1000,
        messages=[
            {
                "role": "user",
                "content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?",
            }
        ],
        tools=[
            {
                "name": "run_python_code",
                "description": "Execute Python code and return the result",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "Python code to execute",
                        }
                    },
                    "required": ["code"],
                },
            }
        ],
        tool_choice={"type": "auto"},
    )

    # Split the response blocks by type: concatenated text vs. tool invocations.
    text_content = "".join(b.text for b in message.content if b.type == "text")
    tool_use_blocks = [b for b in message.content if b.type == "tool_use"]

    logger.info(f"Response from coding-model alias via Anthropic: {text_content}")
    logger.info(f"Tool use blocks: {len(tool_use_blocks)}")

    # Should get either text response or tool use blocks for coding assistance
    assert text_content or tool_use_blocks