Add support for Amazon Bedrock Converse and ConverseStream (#588)

* first commit to get the Bedrock Converse API working; the next commit adds support for streaming and binary frames

* adding translation from BedrockBinaryFrameDecoder to AnthropicMessagesEvent

* Claude Code works with Amazon Bedrock

* added tests for openai streaming from bedrock

* PR comments fixed

* adding support for bedrock in docs as supported provider

* cargo fmt

* reverted to chatgpt models for claude code routing

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-288.local>
Co-authored-by: Adil Hafeez <adil.hafeez@gmail.com>
This commit is contained in:
Salman Paracha 2025-10-22 11:31:21 -07:00 committed by GitHub
parent ba826b1961
commit 9407ae6af7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 7362 additions and 1493 deletions

View file

@ -1,3 +1,4 @@
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
@ -12,8 +13,8 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::metrics::Metrics;
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
use common::consts::{
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, RATELIMIT_SELECTOR_HEADER_KEY,
REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::errors::ServerError;
use common::llm_providers::LlmProviders;
@ -21,9 +22,15 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent};
use hermesllm::apis::sse::{SseEvent, SseStreamIter};
use hermesllm::clients::endpoints::SupportedAPIs;
use hermesllm::providers::response::{ProviderResponse, SseEvent, SseStreamIter};
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
use hermesllm::providers::response::ProviderResponse;
use hermesllm::{
DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType,
ProviderStreamResponseType,
};
pub struct StreamContext {
metrics: Rc<Metrics>,
@ -33,7 +40,7 @@ pub struct StreamContext {
/// The API that is requested by the client (before compatibility mapping)
client_api: Option<SupportedAPIs>,
/// The API that should be used for the upstream provider (after compatibility mapping)
resolved_api: Option<SupportedAPIs>,
resolved_api: Option<SupportedUpstreamAPIs>,
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
@ -45,8 +52,8 @@ pub struct StreamContext {
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
user_message: Option<String>,
/// Store upstream response status code to handle error responses gracefully
upstream_status_code: Option<StatusCode>,
binary_frame_decoder: Option<BedrockBinaryFrameDecoder<bytes::BytesMut>>,
}
impl StreamContext {
@ -75,6 +82,7 @@ impl StreamContext {
request_body_sent_time: None,
user_message: None,
upstream_status_code: None,
binary_frame_decoder: None,
}
}
@ -108,6 +116,7 @@ impl StreamContext {
.model
.as_ref()
.unwrap_or(&"".to_string()),
self.streaming_response,
);
if target_endpoint != request_path {
self.set_http_request_header(":path", Some(&target_endpoint));
@ -148,14 +157,19 @@ impl StreamContext {
// Set API-specific headers based on the resolved upstream API
match self.resolved_api.as_ref() {
Some(SupportedAPIs::AnthropicMessagesAPI(_)) => {
Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
// Anthropic API requires x-api-key and anthropic-version headers
// Remove any existing Authorization header since Anthropic doesn't use it
self.remove_http_request_header("Authorization");
self.set_http_request_header("x-api-key", Some(llm_provider_api_key_value));
self.set_http_request_header("anthropic-version", Some("2023-06-01"));
}
Some(SupportedAPIs::OpenAIChatCompletions(_)) | None => {
Some(
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
| SupportedUpstreamAPIs::AmazonBedrockConverse(_)
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
)
| None => {
// OpenAI and default: use Authorization Bearer token
// Remove any existing x-api-key header since OpenAI doesn't use it
self.remove_http_request_header("x-api-key");
@ -410,7 +424,16 @@ impl StreamContext {
match self.client_api.as_ref() {
Some(client_api) => {
let client_api = client_api.clone(); // Clone to avoid borrowing issues
let upstream_api = provider_id.compatible_api_for_client(&client_api);
let upstream_api =
provider_id.compatible_api_for_client(&client_api, self.streaming_response);
// Check if this is Bedrock binary stream
if matches!(
upstream_api,
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
) {
return self.handle_bedrock_binary_stream(body, &client_api, &upstream_api);
}
// Parse body into SSE iterator using TryFrom
let sse_iter: SseStreamIter<std::vec::IntoIter<String>> =
@ -487,6 +510,127 @@ impl StreamContext {
}
}
/// Feeds one chunk of a Bedrock binary-stream response body into the frame decoder and
/// returns the serialized output for every complete frame that could be converted.
///
/// The decoder is created on the first call and stored on `self`, so bytes belonging to a
/// frame that is still incomplete stay buffered until a later call supplies the rest.
/// Each complete frame is converted via
/// `ProviderStreamResponseType::try_from((&frame, client_api, upstream_api))` and the
/// converted value is appended to the output as a string.
///
/// For `MessagesStreamEvent` results the decoder also tracks which content-block indices
/// have had a `ContentBlockStart`. When a `ContentBlockDelta` arrives for an index with no
/// recorded start, an empty text `ContentBlockStart` is emitted ahead of the delta.
///
/// # Returns
/// * `Ok(bytes)` - output for the frames completed by this chunk; may be empty when only
///   partial frame data is available.
/// * `Err(Action::Continue)` - when `decode_frame` yields `None` (logged as a decode
///   error). Output accumulated earlier in the same call is not returned in that case.
fn handle_bedrock_binary_stream(
    &mut self,
    body: &[u8],
    client_api: &SupportedAPIs,
    upstream_api: &SupportedUpstreamAPIs,
) -> Result<Vec<u8>, Action> {
    // Create the decoder on first use, then append this chunk to its internal buffer.
    self.binary_frame_decoder
        .get_or_insert_with(|| BedrockBinaryFrameDecoder::from_bytes(&[]))
        .buffer_mut()
        .extend_from_slice(body);

    let mut sse_out = Vec::new();

    loop {
        let next_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame();

        // Only a complete frame proceeds; the other two outcomes end this call.
        let frame = match next_frame {
            Some(DecodedFrame::Complete(ref payload)) => DecodedFrame::Complete(payload.clone()),
            Some(DecodedFrame::Incomplete) => {
                // The decoder keeps the partial data; the next chunk should complete it.
                debug!(
                    "[ARCHGW_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data",
                    self.request_identifier()
                );
                break;
            }
            None => {
                warn!(
                    "[ARCHGW_REQ_ID:{}] BEDROCK_DECODE_ERROR",
                    self.request_identifier()
                );
                return Err(Action::Continue);
            }
        };

        // A frame that cannot be converted is logged and skipped; decoding carries on.
        let provider_response =
            match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) {
                Ok(converted) => converted,
                Err(e) => {
                    warn!(
                        "[ARCHGW_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}",
                        self.request_identifier(),
                        e
                    );
                    continue;
                }
            };

        self.record_ttft_if_needed();

        // Keep the content-block bookkeeping in sync for message stream events.
        if let ProviderStreamResponseType::MessagesStreamEvent(evt) = &provider_response {
            match evt {
                MessagesStreamEvent::ContentBlockStart { index, .. } => {
                    // Remember that this index already has its start event.
                    self.binary_frame_decoder
                        .as_mut()
                        .unwrap()
                        .set_content_block_start_sent(*index as i32);

                    debug!(
                        "[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}",
                        self.request_identifier(),
                        *index
                    );
                }
                MessagesStreamEvent::ContentBlockDelta { index, .. } => {
                    let start_already_sent = self
                        .binary_frame_decoder
                        .as_ref()
                        .unwrap()
                        .has_content_block_start_been_sent(*index as i32);

                    if !start_already_sent {
                        // No start was seen for this index: emit an empty text block
                        // start ahead of the delta.
                        let synthetic_start = MessagesStreamEvent::ContentBlockStart {
                            index: *index,
                            content_block: MessagesContentBlock::Text {
                                text: String::new(),
                                cache_control: None,
                            },
                        };
                        let synthetic_sse: String = synthetic_start.into();
                        sse_out.extend_from_slice(synthetic_sse.as_bytes());

                        // Record it so later deltas for this index skip the injection.
                        self.binary_frame_decoder
                            .as_mut()
                            .unwrap()
                            .set_content_block_start_sent(*index as i32);

                        debug!(
                            "[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}",
                            self.request_identifier(),
                            *index
                        );
                    }
                }
                _ => {}
            }
        }

        // Append the converted event itself after any injected start.
        let event_text: String = provider_response.into();
        sse_out.extend_from_slice(event_text.as_bytes());
    }

    // May be empty when this chunk completed no frame.
    Ok(sse_out)
}
fn handle_non_streaming_response(
&mut self,
body: &[u8],
@ -578,6 +722,11 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
self.streaming_response = self
.get_http_request_header(ARCH_IS_STREAMING_HEADER)
.map(|val| val == "true")
.unwrap_or(false);
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
@ -612,7 +761,17 @@ impl HttpContext for StreamContext {
(self.client_api.as_ref(), self.llm_provider.as_ref())
{
let provider_id = provider.to_provider_id();
self.resolved_api = Some(provider_id.compatible_api_for_client(api));
self.resolved_api =
Some(provider_id.compatible_api_for_client(api, self.streaming_response));
debug!(
"[ARCHGW_REQ_ID:{}] ROUTING_INFO: provider='{}' client_api={:?} resolved_api={:?} request_path='{}'",
self.request_identifier(),
provider.to_provider_id(),
api,
self.resolved_api,
request_path
);
} else {
self.resolved_api = None;
}
@ -697,7 +856,7 @@ impl HttpContext for StreamContext {
//We need to deserialize the request body based on the resolved API
let mut deserialized_client_request: ProviderRequestType = match self.client_api.as_ref() {
Some(the_client_api) => {
debug!(
info!(
"[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_RECEIVED: api={:?} body_size={}",
self.request_identifier(),
the_client_api,
@ -795,7 +954,10 @@ impl HttpContext for StreamContext {
);
// Use provider interface for streaming detection and setup
self.streaming_response = deserialized_client_request.is_streaming();
// If streaming_response is not already set from headers, get it from the parsed request
if !self.streaming_response {
self.streaming_response = deserialized_client_request.is_streaming();
}
// Use provider interface for text extraction (after potential mutation)
let input_tokens_str = deserialized_client_request.extract_messages_text();