mirror of
https://github.com/katanemo/plano.git
synced 2026-05-08 23:32:43 +02:00
Add support for Amazon Bedrock Converse and ConverseStream (#588)
* first commit to get Bedrock Converse API working. Next commit support for streaming and binary frames * adding translation from BedrockBinaryFrameDecoder to AnthropicMessagesEvent * Claude Code works with Amazon Bedrock * added tests for openai streaming from bedrock * PR comments fixed * adding support for bedrock in docs as supported provider * cargo fmt * revertted to chatgpt models for claude code routing --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-288.local> Co-authored-by: Adil Hafeez <adil.hafeez@gmail.com>
This commit is contained in:
parent
ba826b1961
commit
9407ae6af7
35 changed files with 7362 additions and 1493 deletions
|
|
@ -1,5 +1,5 @@
|
|||
use crate::apis::{AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs};
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
|
|
@ -19,7 +19,8 @@ pub enum ProviderId {
|
|||
Ollama,
|
||||
Moonshotai,
|
||||
Zhipu,
|
||||
Qwen, // alias for Qwen
|
||||
Qwen,
|
||||
AmazonBedrock,
|
||||
}
|
||||
|
||||
impl From<&str> for ProviderId {
|
||||
|
|
@ -39,7 +40,8 @@ impl From<&str> for ProviderId {
|
|||
"ollama" => ProviderId::Ollama,
|
||||
"moonshotai" => ProviderId::Moonshotai,
|
||||
"zhipu" => ProviderId::Zhipu,
|
||||
"qwen" => ProviderId::Qwen, // alias for Zhipu
|
||||
"qwen" => ProviderId::Qwen, // alias for Qwen
|
||||
"amazon_bedrock" => ProviderId::AmazonBedrock,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
|
|
@ -47,16 +49,20 @@ impl From<&str> for ProviderId {
|
|||
|
||||
impl ProviderId {
|
||||
/// Given a client API, return the compatible upstream API for this provider
|
||||
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs {
|
||||
pub fn compatible_api_for_client(
|
||||
&self,
|
||||
client_api: &SupportedAPIs,
|
||||
is_streaming: bool,
|
||||
) -> SupportedUpstreamAPIs {
|
||||
match (self, client_api) {
|
||||
// Claude/Anthropic providers natively support Anthropic APIs
|
||||
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
}
|
||||
(
|
||||
ProviderId::Anthropic,
|
||||
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(
|
||||
|
|
@ -75,7 +81,7 @@ impl ProviderId {
|
|||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
(
|
||||
ProviderId::OpenAI
|
||||
|
|
@ -93,7 +99,27 @@ impl ProviderId {
|
|||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// Amazon Bedrock natively supports Bedrock APIs
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -116,6 +142,7 @@ impl Display for ProviderId {
|
|||
ProviderId::Moonshotai => write!(f, "moonshotai"),
|
||||
ProviderId::Zhipu => write!(f, "zhipu"),
|
||||
ProviderId::Qwen => write!(f, "qwen"),
|
||||
ProviderId::AmazonBedrock => write!(f, "amazon_bedrock"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
|
||||
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -10,6 +13,8 @@ use std::fmt;
|
|||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
MessagesRequest(MessagesRequest),
|
||||
BedrockConverse(ConverseRequest),
|
||||
BedrockConverseStream(ConverseStreamRequest),
|
||||
//add more request types here
|
||||
}
|
||||
pub trait ProviderRequest: Send + Sync {
|
||||
|
|
@ -42,6 +47,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.model(),
|
||||
Self::MessagesRequest(r) => r.model(),
|
||||
Self::BedrockConverse(r) => r.model(),
|
||||
Self::BedrockConverseStream(r) => r.model(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -49,6 +56,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.set_model(model),
|
||||
Self::MessagesRequest(r) => r.set_model(model),
|
||||
Self::BedrockConverse(r) => r.set_model(model),
|
||||
Self::BedrockConverseStream(r) => r.set_model(model),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -56,6 +65,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.is_streaming(),
|
||||
Self::MessagesRequest(r) => r.is_streaming(),
|
||||
Self::BedrockConverse(_) => false,
|
||||
Self::BedrockConverseStream(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -63,6 +74,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.extract_messages_text(),
|
||||
Self::MessagesRequest(r) => r.extract_messages_text(),
|
||||
Self::BedrockConverse(r) => r.extract_messages_text(),
|
||||
Self::BedrockConverseStream(r) => r.extract_messages_text(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -70,6 +83,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.get_recent_user_message(),
|
||||
Self::MessagesRequest(r) => r.get_recent_user_message(),
|
||||
Self::BedrockConverse(r) => r.get_recent_user_message(),
|
||||
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -77,6 +92,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.to_bytes(),
|
||||
Self::MessagesRequest(r) => r.to_bytes(),
|
||||
Self::BedrockConverse(r) => r.to_bytes(),
|
||||
Self::BedrockConverseStream(r) => r.to_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -84,6 +101,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.metadata(),
|
||||
Self::MessagesRequest(r) => r.metadata(),
|
||||
Self::BedrockConverse(r) => r.metadata(),
|
||||
Self::BedrockConverseStream(r) => r.metadata(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -91,6 +110,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.remove_metadata_key(key),
|
||||
Self::MessagesRequest(r) => r.remove_metadata_key(key),
|
||||
Self::BedrockConverse(r) => r.remove_metadata_key(key),
|
||||
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -120,27 +141,27 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
|||
}
|
||||
|
||||
/// Conversion from one ProviderRequestType to a different ProviderRequestType (SupportedAPIs)
|
||||
impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
||||
impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType {
|
||||
type Error = ProviderRequestError;
|
||||
|
||||
fn try_from(
|
||||
(request, upstream_api): (ProviderRequestType, &SupportedAPIs),
|
||||
(client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
match (request, upstream_api) {
|
||||
match (client_request, upstream_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let messages_req =
|
||||
MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError {
|
||||
|
|
@ -155,7 +176,7 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
|||
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
|
|
@ -168,6 +189,69 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
|||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
|
||||
// Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req =
|
||||
ConverseRequest::try_from(messages_req).map_err(|e| ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
// Amazon Bedrock to other APIs conversions
|
||||
(ProviderRequestType::BedrockConverse(_), _) => {
|
||||
todo!("Amazon Bedrock to ChatCompletionsRequest conversion not implemented yet")
|
||||
}
|
||||
|
||||
(ProviderRequestType::BedrockConverseStream(_), _) => {
|
||||
todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -201,7 +285,7 @@ mod tests {
|
|||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue