mirror of
https://github.com/katanemo/plano.git
synced 2026-06-05 14:45:15 +02:00
Add support for Amazon Bedrock Converse and ConverseStream (#588)
* first commit to get Bedrock Converse API working. Next commit support for streaming and binary frames * adding translation from BedrockBinaryFrameDecoder to AnthropicMessagesEvent * Claude Code works with Amazon Bedrock * added tests for openai streaming from bedrock * PR comments fixed * adding support for bedrock in docs as supported provider * cargo fmt * revertted to chatgpt models for claude code routing --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-288.local> Co-authored-by: Adil Hafeez <adil.hafeez@gmail.com>
This commit is contained in:
parent
ba826b1961
commit
9407ae6af7
35 changed files with 7362 additions and 1493 deletions
2
.github/workflows/e2e_archgw.yml
vendored
2
.github/workflows/e2e_archgw.yml
vendored
|
|
@ -39,6 +39,8 @@ jobs:
|
|||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }}
|
||||
|
||||
run: |
|
||||
docker compose up | tee &> archgw.logs &
|
||||
|
||||
|
|
|
|||
1
.github/workflows/e2e_tests.yml
vendored
1
.github/workflows/e2e_tests.yml
vendored
|
|
@ -32,6 +32,7 @@ jobs:
|
|||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }}
|
||||
run: |
|
||||
python -mvenv venv
|
||||
source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ SUPPORTED_PROVIDERS = [
|
|||
"moonshotai",
|
||||
"zhipu",
|
||||
"qwen",
|
||||
"amazon_bedrock",
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -188,7 +189,10 @@ def validate_and_render_schema():
|
|||
|
||||
# Validate azure_openai and ollama provider requires base_url
|
||||
if (
|
||||
provider == "azure_openai" or provider == "ollama" or provider == "qwen"
|
||||
provider == "azure_openai"
|
||||
or provider == "ollama"
|
||||
or provider == "qwen"
|
||||
or provider == "amazon_bedrock"
|
||||
) and model_provider.get("base_url") is None:
|
||||
raise Exception(
|
||||
f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
|
||||
|
|
|
|||
86
crates/Cargo.lock
generated
86
crates/Cargo.lock
generated
|
|
@ -101,6 +101,35 @@ version = "1.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
||||
|
||||
[[package]]
|
||||
name = "aws-smithy-eventstream"
|
||||
version = "0.60.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9656b85088f8d9dc7ad40f9a6c7228e1e8447cdf4b046c87e152e0805dea02fa"
|
||||
dependencies = [
|
||||
"aws-smithy-types",
|
||||
"bytes",
|
||||
"crc32fast",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-smithy-types"
|
||||
version = "1.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f5b3a7486f6690ba25952cabf1e7d75e34d69eaff5081904a47bc79074d6457"
|
||||
dependencies = [
|
||||
"base64-simd",
|
||||
"bytes",
|
||||
"bytes-utils",
|
||||
"itoa",
|
||||
"num-integer",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"ryu",
|
||||
"serde",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.75"
|
||||
|
|
@ -128,6 +157,16 @@ version = "0.22.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||
|
||||
[[package]]
|
||||
name = "base64-simd"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195"
|
||||
dependencies = [
|
||||
"outref",
|
||||
"vsimd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.5.3"
|
||||
|
|
@ -217,6 +256,16 @@ version = "1.10.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
|
||||
|
||||
[[package]]
|
||||
name = "bytes-utils"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.26"
|
||||
|
|
@ -302,6 +351,15 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
|
|
@ -738,6 +796,8 @@ dependencies = [
|
|||
name = "hermesllm"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"aws-smithy-eventstream",
|
||||
"bytes",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
|
|
@ -1191,6 +1251,7 @@ name = "llm_gateway"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"acap",
|
||||
"bytes",
|
||||
"common",
|
||||
"derivative",
|
||||
"governor",
|
||||
|
|
@ -1359,6 +1420,15 @@ version = "0.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
|
|
@ -1517,6 +1587,12 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "outref"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e"
|
||||
|
||||
[[package]]
|
||||
name = "overload"
|
||||
version = "0.1.1"
|
||||
|
|
@ -2081,9 +2157,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "security-framework-sys"
|
||||
version = "2.14.0"
|
||||
version = "2.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
|
||||
checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
|
|
@ -2812,6 +2888,12 @@ version = "0.9.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
|
||||
[[package]]
|
||||
name = "vsimd"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64"
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.1"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{ModelAlias, ModelUsagePreference};
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
|
|
@ -56,6 +57,7 @@ pub async fn chat(
|
|||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = model_aliases.get(&model_from_request) {
|
||||
debug!(
|
||||
|
|
@ -84,10 +86,16 @@ pub async fn chat(
|
|||
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
|
||||
match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(
|
||||
hermesllm::apis::OpenAIApi::ChatCompletions,
|
||||
),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(ProviderRequestType::MessagesRequest(_)) => {
|
||||
Ok(
|
||||
ProviderRequestType::MessagesRequest(_)
|
||||
| ProviderRequestType::BedrockConverse(_)
|
||||
| ProviderRequestType::BedrockConverseStream(_),
|
||||
) => {
|
||||
// This should not happen after conversion to OpenAI format
|
||||
warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
|
||||
let err_msg = "Request conversion failed".to_string();
|
||||
|
|
@ -190,6 +198,11 @@ pub async fn chat(
|
|||
header::HeaderValue::from_str(&model_name).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
|
||||
if let Some(trace_parent) = trace_parent {
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static("traceparent"),
|
||||
|
|
|
|||
|
|
@ -206,6 +206,8 @@ pub enum LlmProviderType {
|
|||
Zhipu,
|
||||
#[serde(rename = "qwen")]
|
||||
Qwen,
|
||||
#[serde(rename = "amazon_bedrock")]
|
||||
AmazonBedrock,
|
||||
}
|
||||
|
||||
impl Display for LlmProviderType {
|
||||
|
|
@ -225,6 +227,7 @@ impl Display for LlmProviderType {
|
|||
LlmProviderType::Moonshotai => write!(f, "moonshotai"),
|
||||
LlmProviderType::Zhipu => write!(f, "zhipu"),
|
||||
LlmProviderType::Qwen => write!(f, "qwen"),
|
||||
LlmProviderType::AmazonBedrock => write!(f, "amazon_bedrock"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ pub const MODEL_SERVER_NAME: &str = "model_server";
|
|||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const MESSAGES_KEY: &str = "messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const ARCH_IS_STREAMING_HEADER: &str = "x-arch-streaming-request";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
pub const HEALTHZ_PATH: &str = "/healthz";
|
||||
|
|
|
|||
|
|
@ -6,5 +6,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
serde = {version = "1.0.219", features = ["derive"]}
|
||||
serde_json = "1.0.140"
|
||||
serde_with = "3.12.0"
|
||||
serde_with = {version = "3.12.0", features = ["base64"]}
|
||||
thiserror = "2.0.12"
|
||||
aws-smithy-eventstream = "0.60"
|
||||
bytes = "1.10"
|
||||
|
|
|
|||
1149
crates/hermesllm/src/apis/amazon_bedrock.rs
Normal file
1149
crates/hermesllm/src/apis/amazon_bedrock.rs
Normal file
File diff suppressed because it is too large
Load diff
65
crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs
Normal file
65
crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
use aws_smithy_eventstream::frame::MessageFrameDecoder;
|
||||
use bytes::Buf;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// AWS Event Stream frame decoder wrapper
|
||||
pub struct BedrockBinaryFrameDecoder<B>
|
||||
where
|
||||
B: Buf,
|
||||
{
|
||||
decoder: MessageFrameDecoder,
|
||||
buffer: B,
|
||||
content_block_start_indices: HashSet<i32>,
|
||||
}
|
||||
|
||||
impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
|
||||
/// This is a convenience constructor that creates a BytesMut buffer internally
|
||||
pub fn from_bytes(bytes: &[u8]) -> Self {
|
||||
let buffer = bytes::BytesMut::from(bytes);
|
||||
Self {
|
||||
decoder: MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
content_block_start_indices: std::collections::HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> BedrockBinaryFrameDecoder<B>
|
||||
where
|
||||
B: Buf,
|
||||
{
|
||||
pub fn new(buffer: B) -> Self {
|
||||
Self {
|
||||
decoder: MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
content_block_start_indices: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_frame(&mut self) -> Option<DecodedFrame> {
|
||||
match self.decoder.decode_frame(&mut self.buffer) {
|
||||
Ok(frame) => Some(frame),
|
||||
Err(_e) => None, // Fatal decode error
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buffer_mut(&mut self) -> &mut B {
|
||||
&mut self.buffer
|
||||
}
|
||||
|
||||
/// Check if there are any bytes remaining in the buffer
|
||||
pub fn has_remaining(&self) -> bool {
|
||||
self.buffer.has_remaining()
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
}
|
||||
|
||||
/// Mark that a content_block_start event has been sent for the given index
|
||||
pub fn set_content_block_start_sent(&mut self, index: i32) {
|
||||
self.content_block_start_indices.insert(index);
|
||||
}
|
||||
}
|
||||
|
|
@ -5,9 +5,9 @@ use serde_with::skip_serializing_none;
|
|||
use std::collections::HashMap;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::MESSAGES_PATH;
|
||||
|
||||
// Enum for all supported Anthropic APIs
|
||||
|
|
|
|||
|
|
@ -1,7 +1,19 @@
|
|||
pub mod amazon_bedrock;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
pub use anthropic::*;
|
||||
pub use openai::*;
|
||||
pub mod sse;
|
||||
|
||||
// Explicit exports to avoid naming conflicts
|
||||
pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
|
||||
pub use amazon_bedrock::{
|
||||
Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
|
||||
};
|
||||
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
|
||||
pub use openai::{
|
||||
ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
|
||||
};
|
||||
pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice};
|
||||
|
||||
pub trait ApiDefinition {
|
||||
/// Returns the endpoint path for this API
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ use std::fmt::Display;
|
|||
use thiserror::Error;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::CHAT_COMPLETIONS_PATH;
|
||||
|
||||
// ============================================================================
|
||||
|
|
|
|||
196
crates/hermesllm/src/apis/sse.rs
Normal file
196
crates/hermesllm/src/apis/sse.rs
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
use crate::providers::response::ProviderStreamResponse;
|
||||
use crate::providers::response::ProviderStreamResponseType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
// ============================================================================
|
||||
// SSE EVENT CONTAINER
|
||||
// ============================================================================
|
||||
|
||||
/// Represents a single Server-Sent Event with the complete wire format
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SseEvent {
|
||||
#[serde(rename = "data")]
|
||||
pub data: Option<String>, // The JSON payload after "data: "
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
}
|
||||
|
||||
impl SseEvent {
|
||||
/// Check if this event represents the end of the stream
|
||||
pub fn is_done(&self) -> bool {
|
||||
self.data == Some("[DONE]".into())
|
||||
}
|
||||
|
||||
/// Check if this event should be skipped during processing
|
||||
/// This includes ping messages and other provider-specific events that don't contain content
|
||||
pub fn should_skip(&self) -> bool {
|
||||
// Skip ping messages (commonly used by providers for connection keep-alive)
|
||||
self.data == Some(r#"{"type": "ping"}"#.into())
|
||||
}
|
||||
|
||||
/// Check if this is an event-only SSE event (no data payload)
|
||||
pub fn is_event_only(&self) -> bool {
|
||||
self.event.is_some() && self.data.is_none()
|
||||
}
|
||||
|
||||
/// Get the parsed provider response if available
|
||||
pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
|
||||
self.provider_stream_response
|
||||
.as_ref()
|
||||
.map(|resp| resp as &dyn ProviderStreamResponse)
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for SseEvent {
|
||||
type Err = SseParseError;
|
||||
|
||||
fn from_str(line: &str) -> Result<Self, Self::Err> {
|
||||
if line.starts_with("data: ") {
|
||||
let data: String = line[6..].to_string(); // Remove "data: " prefix
|
||||
if data.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty data field is not a valid SSE event".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(SseEvent {
|
||||
data: Some(data),
|
||||
event: None,
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else if line.starts_with("event: ") {
|
||||
//used by Anthropic
|
||||
let event_type = line[7..].to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty event field is not a valid SSE event".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(SseEvent {
|
||||
data: None,
|
||||
event: Some(event_type),
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else {
|
||||
Err(SseParseError {
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", line),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SseEvent {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.sse_transform_buffer)
|
||||
}
|
||||
}
|
||||
|
||||
// Into implementation to convert SseEvent to bytes for response buffer
|
||||
impl Into<Vec<u8>> for SseEvent {
|
||||
fn into(self) -> Vec<u8> {
|
||||
format!("{}\n\n", self.sse_transform_buffer).into_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SseParseError {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl fmt::Display for SseParseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "SSE parse error: {}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for SseParseError {}
|
||||
|
||||
/// Generic SSE (Server-Sent Events) streaming iterator container
|
||||
/// Parses raw SSE lines into SseEvent objects
|
||||
pub struct SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub lines: I,
|
||||
pub done_seen: bool,
|
||||
}
|
||||
|
||||
impl<I> SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(lines: I) -> Self {
|
||||
Self {
|
||||
lines,
|
||||
done_seen: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TryFrom implementation to parse bytes into SseStreamIter
|
||||
// Handles both text-based SSE and binary AWS Event Stream formats
|
||||
impl TryFrom<&[u8]> for SseStreamIter<std::vec::IntoIter<String>> {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
// Parse as text-based SSE format
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
Ok(SseStreamIter::new(lines.into_iter()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
type Item = SseEvent;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
// If we already returned [DONE], terminate the stream
|
||||
if self.done_seen {
|
||||
return None;
|
||||
}
|
||||
|
||||
for line in &mut self.lines {
|
||||
let line_str = line.as_ref();
|
||||
|
||||
// Try to parse as either data: or event: line
|
||||
if let Ok(event) = line_str.parse::<SseEvent>() {
|
||||
// For data: lines, check if this is the [DONE] marker
|
||||
if event.data.is_some() && event.is_done() {
|
||||
self.done_seen = true;
|
||||
return Some(event); // Return [DONE] event for transformation
|
||||
}
|
||||
// For data: lines, skip events that should be filtered at the transport layer
|
||||
if event.data.is_some() && event.should_skip() {
|
||||
continue;
|
||||
}
|
||||
return Some(event);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
|
@ -1,30 +1,5 @@
|
|||
//! Supported endpoint registry for LLM APIs
|
||||
//!
|
||||
//! This module provides a simple registry to check which API endpoint paths
|
||||
//! we support across different providers.
|
||||
//!
|
||||
//! # Examples
|
||||
//!
|
||||
//! ```rust
|
||||
//! use hermesllm::clients::endpoints::supported_endpoints;
|
||||
//!
|
||||
//! // Check if we support an endpoint
|
||||
//! use hermesllm::clients::endpoints::SupportedAPIs;
|
||||
//! assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
|
||||
//! assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
|
||||
//! assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
|
||||
//!
|
||||
//! // Get all supported endpoints
|
||||
//! let endpoints = supported_endpoints();
|
||||
//! assert_eq!(endpoints.len(), 2);
|
||||
//! assert!(endpoints.contains(&"/v1/chat/completions"));
|
||||
//! assert!(endpoints.contains(&"/v1/messages"));
|
||||
//! ```
|
||||
|
||||
use crate::{
|
||||
apis::{AnthropicApi, ApiDefinition, OpenAIApi},
|
||||
ProviderId,
|
||||
};
|
||||
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
|
||||
use crate::ProviderId;
|
||||
use std::fmt;
|
||||
|
||||
/// Unified enum representing all supported API endpoints across providers
|
||||
|
|
@ -34,6 +9,14 @@ pub enum SupportedAPIs {
|
|||
AnthropicMessagesAPI(AnthropicApi),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SupportedUpstreamAPIs {
|
||||
OpenAIChatCompletions(OpenAIApi),
|
||||
AnthropicMessagesAPI(AnthropicApi),
|
||||
AmazonBedrockConverse(AmazonBedrockApi),
|
||||
AmazonBedrockConverseStream(AmazonBedrockApi),
|
||||
}
|
||||
|
||||
impl fmt::Display for SupportedAPIs {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
|
|
@ -74,11 +57,21 @@ impl SupportedAPIs {
|
|||
provider_id: &ProviderId,
|
||||
request_path: &str,
|
||||
model_id: &str,
|
||||
is_streaming: bool,
|
||||
) -> String {
|
||||
let default_endpoint = "/v1/chat/completions".to_string();
|
||||
match self {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
|
||||
ProviderId::Anthropic => "/v1/messages".to_string(),
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") && !is_streaming {
|
||||
format!("/model/{}/converse", model_id)
|
||||
} else if request_path.starts_with("/v1/") && is_streaming {
|
||||
format!("/model/{}/converse-stream", model_id)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
_ => default_endpoint,
|
||||
},
|
||||
_ => match provider_id {
|
||||
|
|
@ -117,6 +110,17 @@ impl SupportedAPIs {
|
|||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
if !is_streaming {
|
||||
format!("/model/{}/converse", model_id)
|
||||
} else {
|
||||
format!("/model/{}/converse-stream", model_id)
|
||||
}
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
_ => default_endpoint,
|
||||
},
|
||||
}
|
||||
|
|
@ -161,7 +165,6 @@ mod tests {
|
|||
fn test_is_supported_endpoint() {
|
||||
// OpenAI endpoints
|
||||
assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
|
||||
|
||||
// Anthropic endpoints
|
||||
assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
|
||||
|
||||
|
|
@ -174,7 +177,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_supported_endpoints() {
|
||||
let endpoints = supported_endpoints();
|
||||
assert_eq!(endpoints.len(), 2);
|
||||
assert_eq!(endpoints.len(), 2); // We have 2 APIs defined
|
||||
assert!(endpoints.contains(&"/v1/chat/completions"));
|
||||
assert!(endpoints.contains(&"/v1/messages"));
|
||||
}
|
||||
|
|
@ -217,7 +220,6 @@ mod tests {
|
|||
endpoint
|
||||
);
|
||||
}
|
||||
|
||||
// Total should match
|
||||
assert_eq!(
|
||||
endpoints.len(),
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -4,12 +4,16 @@
|
|||
pub mod apis;
|
||||
pub mod clients;
|
||||
pub mod providers;
|
||||
pub mod transforms;
|
||||
// Re-export important types and traits
|
||||
pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
pub use apis::sse::{SseEvent, SseStreamIter};
|
||||
pub use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
pub use providers::id::ProviderId;
|
||||
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use providers::response::{
|
||||
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
|
||||
ProviderStreamResponseType, SseEvent, SseStreamIter, TokenUsage,
|
||||
ProviderStreamResponseType, TokenUsage,
|
||||
};
|
||||
|
||||
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
|
||||
|
|
@ -18,6 +22,8 @@ pub const MESSAGES_PATH: &str = "/v1/messages";
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
|
@ -40,7 +46,7 @@ mod tests {
|
|||
let client_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test the new simplified architecture - create SseStreamIter directly
|
||||
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
|
||||
|
|
@ -77,4 +83,156 @@ mod tests {
|
|||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
|
||||
}
|
||||
|
||||
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
|
||||
///
|
||||
/// This test demonstrates how to:
|
||||
/// 1. Use MessageFrameDecoder to decode AWS Event Stream frames
|
||||
/// 2. Handle chunked network arrivals with buffering
|
||||
/// 3. Extract event types from message headers
|
||||
/// 4. Parse JSON payloads from decoded messages
|
||||
/// 5. Reconstruct streaming content from contentBlockDelta events
|
||||
///
|
||||
/// The decoder handles frame boundaries automatically - you just keep calling
|
||||
/// decode_frame() until it returns Incomplete, which means you've processed
|
||||
/// all complete frames in the buffer.
|
||||
#[test]
|
||||
fn test_amazon_bedrock_streaming_response() {
|
||||
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
||||
use bytes::{Buf, BytesMut};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Read the response.hex file from tests/e2e directory
|
||||
// Use absolute path to avoid cargo test working directory issues
|
||||
let test_file =
|
||||
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
|
||||
let response_data = fs::read(&test_file)
|
||||
.unwrap_or_else(|e| panic!("Failed to read {:?}: {}", test_file, e));
|
||||
|
||||
println!("📊 Response data size: {} bytes\n", response_data.len());
|
||||
|
||||
// Create decoder and buffer that implements Buf trait
|
||||
// BytesMut automatically tracks position as decoder advances it!
|
||||
let mut decoder = MessageFrameDecoder::new();
|
||||
let mut simulated_network_buffer = BytesMut::new();
|
||||
let mut frame_count = 0;
|
||||
let mut content_chunks = Vec::new();
|
||||
|
||||
// Simulate chunked network arrivals - process as data comes in
|
||||
let chunk_sizes = vec![50, 100, 75, 200, 150, 300, 500, 1000];
|
||||
let mut offset = 0;
|
||||
let mut chunk_num = 0;
|
||||
|
||||
println!("🔄 Simulating chunked network arrivals...\n");
|
||||
|
||||
// Process chunks as they "arrive" from the network
|
||||
while offset < response_data.len() {
|
||||
// Receive next chunk from network
|
||||
let chunk_size = chunk_sizes[chunk_num % chunk_sizes.len()];
|
||||
let end = (offset + chunk_size).min(response_data.len());
|
||||
let chunk = &response_data[offset..end];
|
||||
|
||||
chunk_num += 1;
|
||||
simulated_network_buffer.extend_from_slice(chunk);
|
||||
offset = end;
|
||||
|
||||
println!(
|
||||
"📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)",
|
||||
chunk_num,
|
||||
chunk.len(),
|
||||
simulated_network_buffer.len(),
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
|
||||
// Try to decode all complete frames from buffer
|
||||
// The Buf trait tracks position automatically!
|
||||
loop {
|
||||
let bytes_before = simulated_network_buffer.remaining();
|
||||
match decoder.decode_frame(&mut simulated_network_buffer) {
|
||||
Ok(DecodedFrame::Complete(message)) => {
|
||||
frame_count += 1;
|
||||
let consumed = bytes_before - simulated_network_buffer.remaining();
|
||||
|
||||
println!(
|
||||
" ✅ Frame {}: decoded ({} bytes, {} bytes remaining)",
|
||||
frame_count,
|
||||
consumed,
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
|
||||
// Get event type from headers
|
||||
let event_type = message
|
||||
.headers()
|
||||
.iter()
|
||||
.find(|h| h.name().as_str() == ":event-type")
|
||||
.and_then(|h| {
|
||||
h.value().as_string().ok().map(|s| s.as_str().to_string())
|
||||
});
|
||||
|
||||
if let Some(ref evt) = event_type {
|
||||
println!(" Event: {}", evt);
|
||||
}
|
||||
|
||||
// Parse payload and extract content
|
||||
let payload = message.payload();
|
||||
if !payload.is_empty() {
|
||||
if let Ok(json) = serde_json::from_slice::<serde_json::Value>(payload) {
|
||||
if event_type.as_deref() == Some("contentBlockDelta") {
|
||||
if let Some(delta) = json.get("delta") {
|
||||
if let Some(text) =
|
||||
delta.get("text").and_then(|t| t.as_str())
|
||||
{
|
||||
println!(" 📝 Content: \"{}\"", text);
|
||||
content_chunks.push(text.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // Continue loop to check for more complete frames in buffer
|
||||
}
|
||||
Ok(DecodedFrame::Incomplete) => {
|
||||
// Not enough data for a complete frame - need more chunks
|
||||
println!(
|
||||
" ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n",
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
break; // Wait for next chunk
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("❌ Frame decode error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("📋 Summary:");
|
||||
println!(" Total chunks received: {}", chunk_num);
|
||||
println!(" Total frames decoded: {}", frame_count);
|
||||
println!(" Total content chunks: {}", content_chunks.len());
|
||||
println!(
|
||||
" Final buffer remaining: {} bytes",
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
|
||||
if !content_chunks.is_empty() {
|
||||
let full_text = content_chunks.join("");
|
||||
println!("\n📄 Full reconstructed content:");
|
||||
println!("{}", full_text);
|
||||
println!("\n Characters: {}", full_text.len());
|
||||
println!(" Estimated tokens: ~{}", full_text.len() / 4);
|
||||
}
|
||||
|
||||
// Ensure we decoded at least one frame
|
||||
assert!(frame_count > 0, "Should decode at least one frame");
|
||||
|
||||
// Ensure all data was consumed - if buffer has remaining bytes, it's a partial frame
|
||||
assert_eq!(
|
||||
simulated_network_buffer.remaining(),
|
||||
0,
|
||||
"All bytes should be consumed, {} bytes remain",
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::apis::{AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs};
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
|
|
@ -19,7 +19,8 @@ pub enum ProviderId {
|
|||
Ollama,
|
||||
Moonshotai,
|
||||
Zhipu,
|
||||
Qwen, // alias for Qwen
|
||||
Qwen,
|
||||
AmazonBedrock,
|
||||
}
|
||||
|
||||
impl From<&str> for ProviderId {
|
||||
|
|
@ -39,7 +40,8 @@ impl From<&str> for ProviderId {
|
|||
"ollama" => ProviderId::Ollama,
|
||||
"moonshotai" => ProviderId::Moonshotai,
|
||||
"zhipu" => ProviderId::Zhipu,
|
||||
"qwen" => ProviderId::Qwen, // alias for Zhipu
|
||||
"qwen" => ProviderId::Qwen, // alias for Qwen
|
||||
"amazon_bedrock" => ProviderId::AmazonBedrock,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
|
|
@ -47,16 +49,20 @@ impl From<&str> for ProviderId {
|
|||
|
||||
impl ProviderId {
|
||||
/// Given a client API, return the compatible upstream API for this provider
|
||||
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs {
|
||||
pub fn compatible_api_for_client(
|
||||
&self,
|
||||
client_api: &SupportedAPIs,
|
||||
is_streaming: bool,
|
||||
) -> SupportedUpstreamAPIs {
|
||||
match (self, client_api) {
|
||||
// Claude/Anthropic providers natively support Anthropic APIs
|
||||
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
}
|
||||
(
|
||||
ProviderId::Anthropic,
|
||||
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(
|
||||
|
|
@ -75,7 +81,7 @@ impl ProviderId {
|
|||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
(
|
||||
ProviderId::OpenAI
|
||||
|
|
@ -93,7 +99,27 @@ impl ProviderId {
|
|||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// Amazon Bedrock natively supports Bedrock APIs
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -116,6 +142,7 @@ impl Display for ProviderId {
|
|||
ProviderId::Moonshotai => write!(f, "moonshotai"),
|
||||
ProviderId::Zhipu => write!(f, "zhipu"),
|
||||
ProviderId::Qwen => write!(f, "qwen"),
|
||||
ProviderId::AmazonBedrock => write!(f, "amazon_bedrock"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
|
||||
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -10,6 +13,8 @@ use std::fmt;
|
|||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
MessagesRequest(MessagesRequest),
|
||||
BedrockConverse(ConverseRequest),
|
||||
BedrockConverseStream(ConverseStreamRequest),
|
||||
//add more request types here
|
||||
}
|
||||
pub trait ProviderRequest: Send + Sync {
|
||||
|
|
@ -42,6 +47,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.model(),
|
||||
Self::MessagesRequest(r) => r.model(),
|
||||
Self::BedrockConverse(r) => r.model(),
|
||||
Self::BedrockConverseStream(r) => r.model(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -49,6 +56,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.set_model(model),
|
||||
Self::MessagesRequest(r) => r.set_model(model),
|
||||
Self::BedrockConverse(r) => r.set_model(model),
|
||||
Self::BedrockConverseStream(r) => r.set_model(model),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -56,6 +65,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.is_streaming(),
|
||||
Self::MessagesRequest(r) => r.is_streaming(),
|
||||
Self::BedrockConverse(_) => false,
|
||||
Self::BedrockConverseStream(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -63,6 +74,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.extract_messages_text(),
|
||||
Self::MessagesRequest(r) => r.extract_messages_text(),
|
||||
Self::BedrockConverse(r) => r.extract_messages_text(),
|
||||
Self::BedrockConverseStream(r) => r.extract_messages_text(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -70,6 +83,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.get_recent_user_message(),
|
||||
Self::MessagesRequest(r) => r.get_recent_user_message(),
|
||||
Self::BedrockConverse(r) => r.get_recent_user_message(),
|
||||
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -77,6 +92,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.to_bytes(),
|
||||
Self::MessagesRequest(r) => r.to_bytes(),
|
||||
Self::BedrockConverse(r) => r.to_bytes(),
|
||||
Self::BedrockConverseStream(r) => r.to_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -84,6 +101,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.metadata(),
|
||||
Self::MessagesRequest(r) => r.metadata(),
|
||||
Self::BedrockConverse(r) => r.metadata(),
|
||||
Self::BedrockConverseStream(r) => r.metadata(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -91,6 +110,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.remove_metadata_key(key),
|
||||
Self::MessagesRequest(r) => r.remove_metadata_key(key),
|
||||
Self::BedrockConverse(r) => r.remove_metadata_key(key),
|
||||
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -120,27 +141,27 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
|||
}
|
||||
|
||||
/// Conversion from one ProviderRequestType to a different ProviderRequestType (SupportedAPIs)
|
||||
impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
||||
impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType {
|
||||
type Error = ProviderRequestError;
|
||||
|
||||
fn try_from(
|
||||
(request, upstream_api): (ProviderRequestType, &SupportedAPIs),
|
||||
(client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
match (request, upstream_api) {
|
||||
match (client_request, upstream_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let messages_req =
|
||||
MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError {
|
||||
|
|
@ -155,7 +176,7 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
|||
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
|
|
@ -168,6 +189,69 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
|||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
|
||||
// Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req =
|
||||
ConverseRequest::try_from(messages_req).map_err(|e| ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
// Amazon Bedrock to other APIs conversions
|
||||
(ProviderRequestType::BedrockConverse(_), _) => {
|
||||
todo!("Amazon Bedrock to ChatCompletionsRequest conversion not implemented yet")
|
||||
}
|
||||
|
||||
(ProviderRequestType::BedrockConverseStream(_), _) => {
|
||||
todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -201,7 +285,7 @@ mod tests {
|
|||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
231
crates/hermesllm/src/transforms/lib.rs
Normal file
231
crates/hermesllm/src/transforms/lib.rs
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
use crate::apis::anthropic::{MessagesContentBlock, MessagesImageSource};
|
||||
use crate::apis::openai::{ContentPart, FunctionCall, ImageUrl, Message, MessageContent, ToolCall};
|
||||
use crate::clients::TransformError;
|
||||
use serde_json::Value;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
pub trait ExtractText {
|
||||
fn extract_text(&self) -> String;
|
||||
}
|
||||
|
||||
/// Trait for utility functions on content collections
|
||||
pub trait ContentUtils<T> {
|
||||
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
|
||||
fn split_for_openai(
|
||||
&self,
|
||||
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
|
||||
}
|
||||
|
||||
/// Helper to create a current unix timestamp
|
||||
pub fn current_timestamp() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
// Content Utilities
|
||||
impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
|
||||
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError> {
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
for block in self {
|
||||
match block {
|
||||
MessagesContentBlock::ToolUse {
|
||||
id, name, input, ..
|
||||
}
|
||||
| MessagesContentBlock::ServerToolUse { id, name, input }
|
||||
| MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
let arguments = serde_json::to_string(&input)?;
|
||||
tool_calls.push(ToolCall {
|
||||
id: id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
})
|
||||
}
|
||||
|
||||
fn split_for_openai(
|
||||
&self,
|
||||
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>
|
||||
{
|
||||
let mut content_parts = Vec::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
let mut tool_results = Vec::new();
|
||||
|
||||
for block in self {
|
||||
match block {
|
||||
MessagesContentBlock::Text { text, .. } => {
|
||||
content_parts.push(ContentPart::Text { text: text.clone() });
|
||||
}
|
||||
MessagesContentBlock::Image { source } => {
|
||||
let url = convert_image_source_to_url(source);
|
||||
content_parts.push(ContentPart::ImageUrl {
|
||||
image_url: ImageUrl {
|
||||
url,
|
||||
detail: Some("auto".to_string()),
|
||||
},
|
||||
});
|
||||
}
|
||||
MessagesContentBlock::ToolUse {
|
||||
id, name, input, ..
|
||||
}
|
||||
| MessagesContentBlock::ServerToolUse { id, name, input }
|
||||
| MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
let arguments = serde_json::to_string(&input)?;
|
||||
tool_calls.push(ToolCall {
|
||||
id: id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
..
|
||||
} => {
|
||||
let result_text = content.extract_text();
|
||||
tool_results.push((
|
||||
tool_use_id.clone(),
|
||||
result_text,
|
||||
is_error.unwrap_or(false),
|
||||
));
|
||||
}
|
||||
MessagesContentBlock::WebSearchToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
}
|
||||
| MessagesContentBlock::CodeExecutionToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
}
|
||||
| MessagesContentBlock::McpToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
} => {
|
||||
let result_text = content.extract_text();
|
||||
tool_results.push((
|
||||
tool_use_id.clone(),
|
||||
result_text,
|
||||
is_error.unwrap_or(false),
|
||||
));
|
||||
}
|
||||
_ => {
|
||||
// Skip unsupported content types
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((content_parts, tool_calls, tool_results))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert image source to URL
|
||||
pub fn convert_image_source_to_url(source: &MessagesImageSource) -> String {
|
||||
match source {
|
||||
MessagesImageSource::Base64 { media_type, data } => {
|
||||
format!("data:{};base64,{}", media_type, data)
|
||||
}
|
||||
MessagesImageSource::Url { url } => url.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert image URL to Anthropic image source
|
||||
fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource {
|
||||
if image_url.url.starts_with("data:") {
|
||||
// Parse data URL
|
||||
let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
|
||||
if parts.len() == 2 {
|
||||
let header = parts[0];
|
||||
let data = parts[1];
|
||||
let media_type = header
|
||||
.strip_prefix("data:")
|
||||
.and_then(|s| s.split(';').next())
|
||||
.unwrap_or("image/jpeg")
|
||||
.to_string();
|
||||
|
||||
MessagesImageSource::Base64 {
|
||||
media_type,
|
||||
data: data.to_string(),
|
||||
}
|
||||
} else {
|
||||
MessagesImageSource::Url {
|
||||
url: image_url.url.clone(),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MessagesImageSource::Url {
|
||||
url: image_url.url.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert OpenAI message to Anthropic content blocks
|
||||
pub fn convert_openai_message_to_anthropic_content(
|
||||
message: &Message,
|
||||
) -> Result<Vec<MessagesContentBlock>, TransformError> {
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
// Handle regular content
|
||||
match &message.content {
|
||||
MessageContent::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
blocks.push(MessagesContentBlock::Text {
|
||||
text: text.clone(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::Parts(parts) => {
|
||||
for part in parts {
|
||||
match part {
|
||||
ContentPart::Text { text } => {
|
||||
blocks.push(MessagesContentBlock::Text {
|
||||
text: text.clone(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
ContentPart::ImageUrl { image_url } => {
|
||||
let source = convert_image_url_to_source(image_url);
|
||||
blocks.push(MessagesContentBlock::Image { source });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if let Some(tool_calls) = &message.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let input: Value = serde_json::from_str(&tool_call.function.arguments)?;
|
||||
blocks.push(MessagesContentBlock::ToolUse {
|
||||
id: tool_call.id.clone(),
|
||||
name: tool_call.function.name.clone(),
|
||||
input,
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(blocks)
|
||||
}
|
||||
25
crates/hermesllm/src/transforms/mod.rs
Normal file
25
crates/hermesllm/src/transforms/mod.rs
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
//! API transformation modules
|
||||
//!
|
||||
//! This module provides organized transformations between the two main LLM API formats:
|
||||
//! - `/v1/chat/completions` (OpenAI format)
|
||||
//! - `/v1/messages` (Anthropic format)
|
||||
//!
|
||||
//! Provider-specific transformations (Bedrock, Groq, etc.) are handled internally
|
||||
//! by the gateway, but the external API surface remains these two standard formats.
|
||||
//! The transformations are split into logical modules for maintainability.
|
||||
|
||||
pub mod lib;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
|
||||
// Re-export commonly used items for convenience
|
||||
pub use lib::*;
|
||||
pub use request::*;
|
||||
pub use response::*;
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS
|
||||
// ============================================================================
|
||||
|
||||
/// Default maximum tokens when converting from OpenAI to Anthropic and no max_tokens is specified
|
||||
pub const DEFAULT_MAX_TOKENS: u32 = 4096;
|
||||
704
crates/hermesllm/src/transforms/request/from_anthropic.rs
Normal file
704
crates/hermesllm/src/transforms/request/from_anthropic.rs
Normal file
|
|
@ -0,0 +1,704 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, ImageBlock,
|
||||
ImageSource, InferenceConfiguration, Message as BedrockMessage, SystemContentBlock,
|
||||
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration,
|
||||
ToolInputSchema, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ToolSpecDefinition,
|
||||
ToolUseBlock,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole, MessagesStopReason,
|
||||
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage,
|
||||
ToolResultContent,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsRequest, ContentPart, FinishReason, Function, FunctionChoice, Message,
|
||||
MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Usage,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
|
||||
type AnthropicMessagesRequest = MessagesRequest;
|
||||
|
||||
// Conversion from Anthropic MessagesRequest to OpenAI ChatCompletionsRequest
|
||||
impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: AnthropicMessagesRequest) -> Result<Self, Self::Error> {
|
||||
let mut openai_messages: Vec<Message> = Vec::new();
|
||||
|
||||
// Convert system prompt to system message if present
|
||||
if let Some(system) = req.system {
|
||||
openai_messages.push(system.into());
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
for message in req.messages {
|
||||
let converted_messages: Vec<Message> = message.try_into()?;
|
||||
openai_messages.extend(converted_messages);
|
||||
}
|
||||
|
||||
// Convert tools and tool choice
|
||||
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
|
||||
let (openai_tool_choice, parallel_tool_calls) =
|
||||
convert_anthropic_tool_choice(req.tool_choice);
|
||||
|
||||
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: req.model,
|
||||
messages: openai_messages,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
max_completion_tokens: Some(req.max_tokens),
|
||||
stream: req.stream,
|
||||
stop: req.stop_sequences,
|
||||
tools: openai_tools,
|
||||
tool_choice: openai_tool_choice,
|
||||
parallel_tool_calls,
|
||||
..Default::default()
|
||||
};
|
||||
_chat_completions_req.suppress_max_tokens_if_o3();
|
||||
_chat_completions_req.fix_temperature_if_gpt5();
|
||||
Ok(_chat_completions_req)
|
||||
}
|
||||
}
|
||||
|
||||
// Conversion from Anthropic MessagesRequest to Amazon Bedrock ConverseRequest
|
||||
impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: AnthropicMessagesRequest) -> Result<Self, Self::Error> {
|
||||
// Convert system prompt to SystemContentBlock if present
|
||||
let system: Option<Vec<SystemContentBlock>> = req.system.map(|system_prompt| {
|
||||
let text = match system_prompt {
|
||||
MessagesSystemPrompt::Single(text) => text,
|
||||
MessagesSystemPrompt::Blocks(blocks) => blocks.extract_text(),
|
||||
};
|
||||
vec![SystemContentBlock::Text { text }]
|
||||
});
|
||||
|
||||
// Convert messages to Bedrock format
|
||||
let messages = if req.messages.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut bedrock_messages = Vec::new();
|
||||
for anthropic_message in req.messages {
|
||||
let bedrock_message: BedrockMessage = anthropic_message.try_into()?;
|
||||
bedrock_messages.push(bedrock_message);
|
||||
}
|
||||
Some(bedrock_messages)
|
||||
};
|
||||
|
||||
// Build inference configuration
|
||||
// Anthropic always requires max_tokens, so we should always include inferenceConfig
|
||||
let inference_config = Some(InferenceConfiguration {
|
||||
max_tokens: Some(req.max_tokens),
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
stop_sequences: req.stop_sequences,
|
||||
});
|
||||
|
||||
// Convert tools and tool choice to ToolConfiguration
|
||||
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
|
||||
let tools = req.tools.map(|anthropic_tools| {
|
||||
anthropic_tools
|
||||
.into_iter()
|
||||
.map(|tool| BedrockTool::ToolSpec {
|
||||
tool_spec: ToolSpecDefinition {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: ToolInputSchema {
|
||||
json: tool.input_schema,
|
||||
},
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
let tool_choice = req.tool_choice.map(|choice| {
|
||||
match choice.kind {
|
||||
MessagesToolChoiceType::Auto => BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
},
|
||||
MessagesToolChoiceType::Any => BedrockToolChoice::Any { any: AnyChoice {} },
|
||||
MessagesToolChoiceType::None => BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
}, // Bedrock doesn't have explicit "none"
|
||||
MessagesToolChoiceType::Tool => {
|
||||
if let Some(name) = choice.name {
|
||||
BedrockToolChoice::Tool {
|
||||
tool: ToolChoiceSpec { name },
|
||||
}
|
||||
} else {
|
||||
BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Some(ToolConfiguration { tools, tool_choice })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ConverseRequest {
|
||||
model_id: req.model,
|
||||
messages,
|
||||
system,
|
||||
inference_config,
|
||||
tool_config,
|
||||
stream: req.stream.unwrap_or(false),
|
||||
guardrail_config: None,
|
||||
additional_model_request_fields: None,
|
||||
additional_model_response_field_paths: None,
|
||||
performance_config: None,
|
||||
prompt_variables: None,
|
||||
request_metadata: None,
|
||||
metadata: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Message Conversions
|
||||
impl TryFrom<MessagesMessage> for Vec<Message> {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(message: MessagesMessage) -> Result<Self, Self::Error> {
|
||||
let mut result = Vec::new();
|
||||
|
||||
match message.content {
|
||||
MessagesMessageContent::Single(text) => {
|
||||
result.push(Message {
|
||||
role: message.role.into(),
|
||||
content: MessageContent::Text(text),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
MessagesMessageContent::Blocks(blocks) => {
|
||||
let (content_parts, tool_calls, tool_results) = blocks.split_for_openai()?;
|
||||
// Add tool result messages
|
||||
for (tool_use_id, result_text, _is_error) in tool_results {
|
||||
result.push(Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::Text(result_text),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_use_id),
|
||||
});
|
||||
}
|
||||
|
||||
// Only create main message if there's actual content or tool calls
|
||||
// Skip creating empty content messages (e.g., when message only contains tool_result blocks)
|
||||
if !content_parts.is_empty() || !tool_calls.is_empty() {
|
||||
let content = build_openai_content(content_parts, &tool_calls);
|
||||
let main_message = Message {
|
||||
role: message.role.into(),
|
||||
content,
|
||||
name: None,
|
||||
tool_calls: if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
},
|
||||
tool_call_id: None,
|
||||
};
|
||||
result.push(main_message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
// Role Conversions
|
||||
impl Into<Role> for MessagesRole {
|
||||
fn into(self) -> Role {
|
||||
match self {
|
||||
MessagesRole::User => Role::User,
|
||||
MessagesRole::Assistant => Role::Assistant,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<MessagesStopReason> for FinishReason {
|
||||
fn into(self) -> MessagesStopReason {
|
||||
match self {
|
||||
FinishReason::Stop => MessagesStopReason::EndTurn,
|
||||
FinishReason::Length => MessagesStopReason::MaxTokens,
|
||||
FinishReason::ToolCalls => MessagesStopReason::ToolUse,
|
||||
FinishReason::ContentFilter => MessagesStopReason::Refusal,
|
||||
FinishReason::FunctionCall => MessagesStopReason::ToolUse,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<MessagesUsage> for Usage {
|
||||
fn into(self) -> MessagesUsage {
|
||||
MessagesUsage {
|
||||
input_tokens: self.prompt_tokens,
|
||||
output_tokens: self.completion_tokens,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// System Prompt Conversions
|
||||
impl Into<Message> for MessagesSystemPrompt {
|
||||
fn into(self) -> Message {
|
||||
let system_content = match self {
|
||||
MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
|
||||
MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
|
||||
};
|
||||
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: system_content,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Utility Functions
|
||||
/// Convert Anthropic tools to OpenAI format
|
||||
fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.input_schema,
|
||||
strict: None,
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convert Anthropic tool choice to OpenAI format
|
||||
fn convert_anthropic_tool_choice(
|
||||
tool_choice: Option<MessagesToolChoice>,
|
||||
) -> (Option<ToolChoice>, Option<bool>) {
|
||||
match tool_choice {
|
||||
Some(choice) => {
|
||||
let openai_choice = match choice.kind {
|
||||
MessagesToolChoiceType::Auto => ToolChoice::Type(ToolChoiceType::Auto),
|
||||
MessagesToolChoiceType::Any => ToolChoice::Type(ToolChoiceType::Required),
|
||||
MessagesToolChoiceType::None => ToolChoice::Type(ToolChoiceType::None),
|
||||
MessagesToolChoiceType::Tool => {
|
||||
if let Some(name) = choice.name {
|
||||
ToolChoice::Function {
|
||||
choice_type: "function".to_string(),
|
||||
function: FunctionChoice { name },
|
||||
}
|
||||
} else {
|
||||
ToolChoice::Type(ToolChoiceType::Auto)
|
||||
}
|
||||
}
|
||||
};
|
||||
let parallel = choice.disable_parallel_tool_use.map(|disable| !disable);
|
||||
(Some(openai_choice), parallel)
|
||||
}
|
||||
None => (None, None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build OpenAI message content from parts and tool calls
|
||||
fn build_openai_content(
|
||||
content_parts: Vec<ContentPart>,
|
||||
tool_calls: &[ToolCall],
|
||||
) -> MessageContent {
|
||||
if content_parts.len() == 1 && tool_calls.is_empty() {
|
||||
match &content_parts[0] {
|
||||
ContentPart::Text { text } => MessageContent::Text(text.clone()),
|
||||
_ => MessageContent::Parts(content_parts),
|
||||
}
|
||||
} else if content_parts.is_empty() {
|
||||
MessageContent::Text("".to_string())
|
||||
} else {
|
||||
MessageContent::Parts(content_parts)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<MessagesMessage> for BedrockMessage {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(message: MessagesMessage) -> Result<Self, Self::Error> {
|
||||
let role = match message.role {
|
||||
MessagesRole::User => ConversationRole::User,
|
||||
MessagesRole::Assistant => ConversationRole::Assistant,
|
||||
};
|
||||
|
||||
let mut content_blocks = Vec::new();
|
||||
|
||||
// Convert content blocks
|
||||
match message.content {
|
||||
MessagesMessageContent::Single(text) => {
|
||||
if !text.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
}
|
||||
MessagesMessageContent::Blocks(blocks) => {
|
||||
for block in blocks {
|
||||
match block {
|
||||
crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => {
|
||||
if !text.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
}
|
||||
crate::apis::anthropic::MessagesContentBlock::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
..
|
||||
} => {
|
||||
content_blocks.push(ContentBlock::ToolUse {
|
||||
tool_use: ToolUseBlock {
|
||||
tool_use_id: id,
|
||||
name,
|
||||
input,
|
||||
},
|
||||
});
|
||||
}
|
||||
crate::apis::anthropic::MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
is_error,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
// Convert Anthropic ToolResultContent to Bedrock ToolResultContentBlock
|
||||
let tool_result_content = match content {
|
||||
ToolResultContent::Text(text) => {
|
||||
vec![ToolResultContentBlock::Text { text }]
|
||||
}
|
||||
ToolResultContent::Blocks(blocks) => {
|
||||
let mut result_blocks = Vec::new();
|
||||
for result_block in blocks {
|
||||
match result_block {
|
||||
crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => {
|
||||
result_blocks.push(ToolResultContentBlock::Text { text });
|
||||
}
|
||||
// For now, skip other content types in tool results
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
result_blocks
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure we have at least one content block
|
||||
let final_content = if tool_result_content.is_empty() {
|
||||
vec![ToolResultContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
}]
|
||||
} else {
|
||||
tool_result_content
|
||||
};
|
||||
|
||||
let status = if is_error.unwrap_or(false) {
|
||||
Some(ToolResultStatus::Error)
|
||||
} else {
|
||||
Some(ToolResultStatus::Success)
|
||||
};
|
||||
|
||||
content_blocks.push(ContentBlock::ToolResult {
|
||||
tool_result: ToolResultBlock {
|
||||
tool_use_id,
|
||||
content: final_content,
|
||||
status,
|
||||
},
|
||||
});
|
||||
}
|
||||
crate::apis::anthropic::MessagesContentBlock::Image { source } => {
|
||||
// Convert Anthropic image to Bedrock image format
|
||||
match source {
|
||||
crate::apis::anthropic::MessagesImageSource::Base64 {
|
||||
media_type,
|
||||
data,
|
||||
} => {
|
||||
content_blocks.push(ContentBlock::Image {
|
||||
image: ImageBlock {
|
||||
source: ImageSource::Base64 { media_type, data },
|
||||
},
|
||||
});
|
||||
}
|
||||
crate::apis::anthropic::MessagesImageSource::Url { .. } => {
|
||||
// Bedrock doesn't support URL-based images, skip for now
|
||||
// Could potentially download and convert to base64, but not implemented
|
||||
}
|
||||
}
|
||||
}
|
||||
// Skip other content types for now (Thinking, Document, etc.)
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have at least one content block
|
||||
if content_blocks.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(BedrockMessage {
|
||||
role,
|
||||
content: content_blocks,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
|
||||
ToolChoice as BedrockToolChoice,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
|
||||
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_to_bedrock_basic_request() {
|
||||
let anthropic_request = MessagesRequest {
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Hello, how are you?".to_string()),
|
||||
}],
|
||||
max_tokens: 1000,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
system: Some(MessagesSystemPrompt::Single(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)),
|
||||
metadata: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(0.9),
|
||||
top_k: None,
|
||||
stream: Some(false),
|
||||
stop_sequences: Some(vec!["STOP".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap();
|
||||
|
||||
assert_eq!(bedrock_request.model_id, "claude-3-5-sonnet-20241022");
|
||||
assert!(bedrock_request.system.is_some());
|
||||
assert_eq!(bedrock_request.system.as_ref().unwrap().len(), 1);
|
||||
assert!(bedrock_request.messages.is_some());
|
||||
let messages = bedrock_request.messages.as_ref().unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].role, ConversationRole::User);
|
||||
|
||||
if let ContentBlock::Text { text } = &messages[0].content[0] {
|
||||
assert_eq!(text, "Hello, how are you?");
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
||||
let inference_config = bedrock_request.inference_config.as_ref().unwrap();
|
||||
assert_eq!(inference_config.temperature, Some(0.7));
|
||||
assert_eq!(inference_config.top_p, Some(0.9));
|
||||
assert_eq!(inference_config.max_tokens, Some(1000));
|
||||
assert_eq!(
|
||||
inference_config.stop_sequences,
|
||||
Some(vec!["STOP".to_string()])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_to_bedrock_with_tools() {
|
||||
let anthropic_request = MessagesRequest {
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("What's the weather like?".to_string()),
|
||||
}],
|
||||
max_tokens: 1000,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
system: None,
|
||||
metadata: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
stream: None,
|
||||
stop_sequences: None,
|
||||
tools: Some(vec![MessagesTool {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get current weather information".to_string()),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
}]),
|
||||
tool_choice: Some(MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Tool,
|
||||
name: Some("get_weather".to_string()),
|
||||
disable_parallel_tool_use: None,
|
||||
}),
|
||||
};
|
||||
|
||||
let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap();
|
||||
|
||||
assert_eq!(bedrock_request.model_id, "claude-3-5-sonnet-20241022");
|
||||
assert!(bedrock_request.tool_config.is_some());
|
||||
|
||||
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
|
||||
assert!(tool_config.tools.is_some());
|
||||
let tools = tool_config.tools.as_ref().unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
|
||||
assert_eq!(tool_spec.name, "get_weather");
|
||||
assert_eq!(
|
||||
tool_spec.description,
|
||||
Some("Get current weather information".to_string())
|
||||
);
|
||||
|
||||
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
|
||||
assert_eq!(tool.name, "get_weather");
|
||||
} else {
|
||||
panic!("Expected specific tool choice");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_to_bedrock_auto_tool_choice() {
|
||||
let anthropic_request = MessagesRequest {
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Help me with something".to_string()),
|
||||
}],
|
||||
max_tokens: 500,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
system: None,
|
||||
metadata: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
stream: None,
|
||||
stop_sequences: None,
|
||||
tools: Some(vec![MessagesTool {
|
||||
name: "help_tool".to_string(),
|
||||
description: Some("A helpful tool".to_string()),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}),
|
||||
}]),
|
||||
tool_choice: Some(MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
disable_parallel_tool_use: None,
|
||||
}),
|
||||
};
|
||||
|
||||
let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap();
|
||||
|
||||
assert!(bedrock_request.tool_config.is_some());
|
||||
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
|
||||
assert!(matches!(
|
||||
tool_config.tool_choice,
|
||||
Some(BedrockToolChoice::Auto { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_to_bedrock_multi_message_conversation() {
|
||||
let anthropic_request = MessagesRequest {
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
messages: vec![
|
||||
MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Hello".to_string()),
|
||||
},
|
||||
MessagesMessage {
|
||||
role: MessagesRole::Assistant,
|
||||
content: MessagesMessageContent::Single(
|
||||
"Hi there! How can I help you?".to_string(),
|
||||
),
|
||||
},
|
||||
MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("What's 2+2?".to_string()),
|
||||
},
|
||||
],
|
||||
max_tokens: 100,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
system: Some(MessagesSystemPrompt::Single("Be concise".to_string())),
|
||||
metadata: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: Some(0.5),
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
stream: None,
|
||||
stop_sequences: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap();
|
||||
|
||||
assert!(bedrock_request.messages.is_some());
|
||||
let messages = bedrock_request.messages.as_ref().unwrap();
|
||||
assert_eq!(messages.len(), 3);
|
||||
assert_eq!(messages[0].role, ConversationRole::User);
|
||||
assert_eq!(messages[1].role, ConversationRole::Assistant);
|
||||
assert_eq!(messages[2].role, ConversationRole::User);
|
||||
|
||||
// Check system prompt
|
||||
assert!(bedrock_request.system.is_some());
|
||||
if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] {
|
||||
assert_eq!(text, "Be concise");
|
||||
} else {
|
||||
panic!("Expected system text block");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_message_to_bedrock_conversion() {
|
||||
let anthropic_message = MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Test message".to_string()),
|
||||
};
|
||||
|
||||
let bedrock_message: BedrockMessage = anthropic_message.try_into().unwrap();
|
||||
|
||||
assert_eq!(bedrock_message.role, ConversationRole::User);
|
||||
assert_eq!(bedrock_message.content.len(), 1);
|
||||
|
||||
if let ContentBlock::Text { text } = &bedrock_message.content[0] {
|
||||
assert_eq!(text, "Test message");
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
}
|
||||
}
|
||||
782
crates/hermesllm/src/transforms/request/from_openai.rs
Normal file
782
crates/hermesllm/src/transforms/request/from_openai.rs
Normal file
|
|
@ -0,0 +1,782 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, InferenceConfiguration,
|
||||
Message as BedrockMessage, SystemContentBlock, Tool as BedrockTool,
|
||||
ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration, ToolInputSchema,
|
||||
ToolSpecDefinition,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
|
||||
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
|
||||
ToolResultContent,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::transforms::lib::*;
|
||||
use crate::transforms::*;
|
||||
|
||||
type AnthropicMessagesRequest = MessagesRequest;
|
||||
|
||||
// ============================================================================
|
||||
// MAIN REQUEST TRANSFORMATIONS
|
||||
// ============================================================================
|
||||
|
||||
impl Into<MessagesSystemPrompt> for Message {
|
||||
fn into(self) -> MessagesSystemPrompt {
|
||||
let system_text = match self.content {
|
||||
MessageContent::Text(text) => text,
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
};
|
||||
MessagesSystemPrompt::Single(system_text)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Message> for MessagesMessage {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(message: Message) -> Result<Self, Self::Error> {
|
||||
let role = match message.role {
|
||||
Role::User => MessagesRole::User,
|
||||
Role::Assistant => MessagesRole::Assistant,
|
||||
Role::Tool => {
|
||||
// Tool messages become user messages with tool results
|
||||
let tool_call_id = message.tool_call_id.ok_or_else(|| {
|
||||
TransformError::MissingField(
|
||||
"tool_call_id required for Tool messages".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
return Ok(MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Blocks(vec![
|
||||
MessagesContentBlock::ToolResult {
|
||||
tool_use_id: tool_call_id,
|
||||
is_error: None,
|
||||
content: ToolResultContent::Blocks(vec![MessagesContentBlock::Text {
|
||||
text: message.content.extract_text(),
|
||||
cache_control: None,
|
||||
}]),
|
||||
cache_control: None,
|
||||
},
|
||||
]),
|
||||
});
|
||||
}
|
||||
Role::System => {
|
||||
return Err(TransformError::UnsupportedConversion(
|
||||
"System messages should be handled separately".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let content_blocks = convert_openai_message_to_anthropic_content(&message)?;
|
||||
let content = build_anthropic_content(content_blocks);
|
||||
|
||||
Ok(MessagesMessage { role, content })
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Message> for BedrockMessage {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(message: Message) -> Result<Self, Self::Error> {
|
||||
let role = match message.role {
|
||||
Role::User => ConversationRole::User,
|
||||
Role::Assistant => ConversationRole::Assistant,
|
||||
Role::Tool => ConversationRole::User, // Tool results become user messages in Bedrock
|
||||
Role::System => {
|
||||
return Err(TransformError::UnsupportedConversion(
|
||||
"System messages should be handled separately in Bedrock".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut content_blocks = Vec::new();
|
||||
|
||||
// Handle different message types
|
||||
match message.role {
|
||||
Role::User => {
|
||||
// Convert user message content to content blocks
|
||||
match message.content {
|
||||
MessageContent::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
}
|
||||
MessageContent::Parts(parts) => {
|
||||
// Convert OpenAI content parts to Bedrock ContentBlocks
|
||||
for part in parts {
|
||||
match part {
|
||||
crate::apis::openai::ContentPart::Text { text } => {
|
||||
if !text.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
}
|
||||
crate::apis::openai::ContentPart::ImageUrl { image_url } => {
|
||||
// Convert image URL to Bedrock image format
|
||||
if image_url.url.starts_with("data:") {
|
||||
if let Some((media_type, data)) =
|
||||
parse_data_url(&image_url.url)
|
||||
{
|
||||
content_blocks.push(ContentBlock::Image {
|
||||
image: crate::apis::amazon_bedrock::ImageBlock {
|
||||
source: crate::apis::amazon_bedrock::ImageSource::Base64 {
|
||||
media_type,
|
||||
data,
|
||||
},
|
||||
},
|
||||
});
|
||||
} else {
|
||||
return Err(TransformError::UnsupportedConversion(
|
||||
format!(
|
||||
"Invalid data URL format: {}",
|
||||
image_url.url
|
||||
),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
return Err(TransformError::UnsupportedConversion(
|
||||
"Only base64 data URLs are supported for images in Bedrock".to_string()
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have at least one content block
|
||||
if content_blocks.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Role::Assistant => {
|
||||
// Handle text content - but only add if non-empty OR if we don't have tool calls
|
||||
let text_content = message.content.extract_text();
|
||||
let has_tool_calls = message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map_or(false, |calls| !calls.is_empty());
|
||||
|
||||
// Add text content if it's non-empty, or if we have no tool calls (to avoid empty content)
|
||||
if !text_content.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text: text_content });
|
||||
} else if !has_tool_calls {
|
||||
// If we have empty content and no tool calls, add a minimal placeholder
|
||||
// This prevents the "blank text field" error
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Convert tool calls to ToolUse content blocks
|
||||
if let Some(tool_calls) = message.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
// Parse the arguments string as JSON
|
||||
let input: serde_json::Value =
|
||||
serde_json::from_str(&tool_call.function.arguments).map_err(|e| {
|
||||
TransformError::UnsupportedConversion(format!(
|
||||
"Failed to parse tool arguments as JSON: {}. Arguments: {}",
|
||||
e, tool_call.function.arguments
|
||||
))
|
||||
})?;
|
||||
|
||||
content_blocks.push(ContentBlock::ToolUse {
|
||||
tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
|
||||
tool_use_id: tool_call.id,
|
||||
name: tool_call.function.name,
|
||||
input,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Bedrock requires at least one content block
|
||||
if content_blocks.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Role::Tool => {
|
||||
// Tool messages become user messages with ToolResult content blocks
|
||||
let tool_call_id = message.tool_call_id.ok_or_else(|| {
|
||||
TransformError::MissingField(
|
||||
"tool_call_id required for Tool messages".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_content = message.content.extract_text();
|
||||
|
||||
// Create ToolResult content block
|
||||
let tool_result_content = if tool_content.is_empty() {
|
||||
// Even for tool results, we need non-empty content
|
||||
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
}]
|
||||
} else {
|
||||
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
|
||||
text: tool_content,
|
||||
}]
|
||||
};
|
||||
|
||||
content_blocks.push(ContentBlock::ToolResult {
|
||||
tool_result: crate::apis::amazon_bedrock::ToolResultBlock {
|
||||
tool_use_id: tool_call_id,
|
||||
content: tool_result_content,
|
||||
status: Some(crate::apis::amazon_bedrock::ToolResultStatus::Success), // Default to success
|
||||
},
|
||||
});
|
||||
}
|
||||
Role::System => {
|
||||
// Already handled above with early return
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
Ok(BedrockMessage {
|
||||
role,
|
||||
content: content_blocks,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
|
||||
let mut system_prompt = None;
|
||||
let mut messages = Vec::new();
|
||||
|
||||
for message in req.messages {
|
||||
match message.role {
|
||||
Role::System => {
|
||||
system_prompt = Some(message.into());
|
||||
}
|
||||
_ => {
|
||||
let anthropic_message: MessagesMessage = message.try_into()?;
|
||||
messages.push(anthropic_message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tools and tool choice
|
||||
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
|
||||
let anthropic_tool_choice =
|
||||
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
|
||||
|
||||
Ok(AnthropicMessagesRequest {
|
||||
model: req.model,
|
||||
system: system_prompt,
|
||||
messages,
|
||||
max_tokens: req
|
||||
.max_completion_tokens
|
||||
.or(req.max_tokens)
|
||||
.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
top_k: None, // OpenAI doesn't have top_k
|
||||
stream: req.stream,
|
||||
stop_sequences: req.stop,
|
||||
tools: anthropic_tools,
|
||||
tool_choice: anthropic_tool_choice,
|
||||
metadata: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
|
||||
// Separate system messages from user/assistant messages
|
||||
let mut system_messages = Vec::new();
|
||||
let mut conversation_messages = Vec::new();
|
||||
|
||||
for message in req.messages {
|
||||
match message.role {
|
||||
Role::System => {
|
||||
let system_text = match message.content {
|
||||
MessageContent::Text(text) => text,
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
};
|
||||
system_messages.push(SystemContentBlock::Text { text: system_text });
|
||||
}
|
||||
_ => {
|
||||
let bedrock_message: BedrockMessage = message.try_into()?;
|
||||
conversation_messages.push(bedrock_message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert system messages
|
||||
let system = if system_messages.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(system_messages)
|
||||
};
|
||||
|
||||
// Convert conversation messages
|
||||
let messages = if conversation_messages.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(conversation_messages)
|
||||
};
|
||||
|
||||
// Build inference configuration
|
||||
let max_tokens = req.max_completion_tokens.or(req.max_tokens);
|
||||
let inference_config = if max_tokens.is_some()
|
||||
|| req.temperature.is_some()
|
||||
|| req.top_p.is_some()
|
||||
|| req.stop.is_some()
|
||||
{
|
||||
Some(InferenceConfiguration {
|
||||
max_tokens,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
stop_sequences: req.stop,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Convert tools and tool choice to ToolConfiguration
|
||||
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
|
||||
let tools = req.tools.map(|openai_tools| {
|
||||
openai_tools
|
||||
.into_iter()
|
||||
.map(|tool| BedrockTool::ToolSpec {
|
||||
tool_spec: ToolSpecDefinition {
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
input_schema: ToolInputSchema {
|
||||
json: tool.function.parameters,
|
||||
},
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
let tool_choice = req
|
||||
.tool_choice
|
||||
.map(|choice| {
|
||||
match choice {
|
||||
ToolChoice::Type(tool_type) => match tool_type {
|
||||
ToolChoiceType::Auto => BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
},
|
||||
ToolChoiceType::Required => {
|
||||
BedrockToolChoice::Any { any: AnyChoice {} }
|
||||
}
|
||||
ToolChoiceType::None => BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
}, // Bedrock doesn't have explicit "none"
|
||||
},
|
||||
ToolChoice::Function { function, .. } => BedrockToolChoice::Tool {
|
||||
tool: ToolChoiceSpec {
|
||||
name: function.name,
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
.or_else(|| {
|
||||
// If tools are present but no tool_choice specified, default to "auto"
|
||||
if tools.is_some() {
|
||||
Some(BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
Some(ToolConfiguration { tools, tool_choice })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ConverseRequest {
|
||||
model_id: req.model,
|
||||
messages,
|
||||
system,
|
||||
inference_config,
|
||||
tool_config,
|
||||
stream: req.stream.unwrap_or(false),
|
||||
guardrail_config: None,
|
||||
additional_model_request_fields: None,
|
||||
additional_model_response_field_paths: None,
|
||||
performance_config: None,
|
||||
prompt_variables: None,
|
||||
request_metadata: None,
|
||||
metadata: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert OpenAI tools to Anthropic format
|
||||
fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| MessagesTool {
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
input_schema: tool.function.parameters,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convert OpenAI tool choice to Anthropic format
|
||||
fn convert_openai_tool_choice(
|
||||
tool_choice: Option<ToolChoice>,
|
||||
parallel_tool_calls: Option<bool>,
|
||||
) -> Option<MessagesToolChoice> {
|
||||
tool_choice.map(|choice| match choice {
|
||||
ToolChoice::Type(tool_type) => match tool_type {
|
||||
ToolChoiceType::Auto => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
ToolChoiceType::Required => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Any,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
ToolChoiceType::None => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::None,
|
||||
name: None,
|
||||
disable_parallel_tool_use: None,
|
||||
},
|
||||
},
|
||||
ToolChoice::Function { function, .. } => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Tool,
|
||||
name: Some(function.name),
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Build Anthropic message content from content blocks
|
||||
fn build_anthropic_content(content_blocks: Vec<MessagesContentBlock>) -> MessagesMessageContent {
|
||||
if content_blocks.len() == 1 {
|
||||
match &content_blocks[0] {
|
||||
MessagesContentBlock::Text { text, .. } => MessagesMessageContent::Single(text.clone()),
|
||||
_ => MessagesMessageContent::Blocks(content_blocks),
|
||||
}
|
||||
} else if content_blocks.is_empty() {
|
||||
MessagesMessageContent::Single("".to_string())
|
||||
} else {
|
||||
MessagesMessageContent::Blocks(content_blocks)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a data URL into media type and base64 data
|
||||
/// Supports format: data:image/jpeg;base64,<data>
|
||||
fn parse_data_url(url: &str) -> Option<(String, String)> {
|
||||
if !url.starts_with("data:") {
|
||||
return None;
|
||||
}
|
||||
|
||||
let without_prefix = &url[5..]; // Remove "data:" prefix
|
||||
let parts: Vec<&str> = without_prefix.splitn(2, ',').collect();
|
||||
|
||||
if parts.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let header = parts[0];
|
||||
let data = parts[1];
|
||||
|
||||
// Parse header: "image/jpeg;base64" or just "image/jpeg"
|
||||
let header_parts: Vec<&str> = header.split(';').collect();
|
||||
if header_parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let media_type = header_parts[0].to_string();
|
||||
|
||||
// Check if it's base64 encoded
|
||||
if header_parts.len() > 1 && header_parts[1] == "base64" {
|
||||
Some((media_type, data.to_string()))
|
||||
} else {
|
||||
// For now, only support base64 encoding
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
|
||||
ToolChoice as BedrockToolChoice,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsRequest, Function, FunctionChoice, Message, MessageContent, Role, Tool,
|
||||
ToolChoice, ToolChoiceType,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_bedrock_basic_request() {
|
||||
let openai_request = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant.".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello, how are you?".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(0.9),
|
||||
max_completion_tokens: Some(1000),
|
||||
stop: Some(vec!["STOP".to_string()]),
|
||||
stream: Some(false),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let bedrock_request: ConverseRequest = openai_request.try_into().unwrap();
|
||||
|
||||
assert_eq!(bedrock_request.model_id, "gpt-4");
|
||||
assert!(bedrock_request.system.is_some());
|
||||
assert_eq!(bedrock_request.system.as_ref().unwrap().len(), 1);
|
||||
|
||||
if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] {
|
||||
assert_eq!(text, "You are a helpful assistant.");
|
||||
} else {
|
||||
panic!("Expected system text block");
|
||||
}
|
||||
|
||||
assert!(bedrock_request.messages.is_some());
|
||||
let messages = bedrock_request.messages.as_ref().unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].role, ConversationRole::User);
|
||||
|
||||
if let ContentBlock::Text { text } = &messages[0].content[0] {
|
||||
assert_eq!(text, "Hello, how are you?");
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
||||
let inference_config = bedrock_request.inference_config.as_ref().unwrap();
|
||||
assert_eq!(inference_config.temperature, Some(0.7));
|
||||
assert_eq!(inference_config.top_p, Some(0.9));
|
||||
assert_eq!(inference_config.max_tokens, Some(1000));
|
||||
assert_eq!(
|
||||
inference_config.stop_sequences,
|
||||
Some(vec!["STOP".to_string()])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_bedrock_with_tools() {
|
||||
let openai_request = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("What's the weather like?".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_completion_tokens: Some(1000),
|
||||
stop: None,
|
||||
stream: None,
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get current weather information".to_string()),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Function {
|
||||
choice_type: "function".to_string(),
|
||||
function: FunctionChoice {
|
||||
name: "get_weather".to_string(),
|
||||
},
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let bedrock_request: ConverseRequest = openai_request.try_into().unwrap();
|
||||
|
||||
assert_eq!(bedrock_request.model_id, "gpt-4");
|
||||
assert!(bedrock_request.tool_config.is_some());
|
||||
|
||||
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
|
||||
assert!(tool_config.tools.is_some());
|
||||
let tools = tool_config.tools.as_ref().unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
|
||||
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
|
||||
assert_eq!(tool_spec.name, "get_weather");
|
||||
assert_eq!(
|
||||
tool_spec.description,
|
||||
Some("Get current weather information".to_string())
|
||||
);
|
||||
|
||||
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
|
||||
assert_eq!(tool.name, "get_weather");
|
||||
} else {
|
||||
panic!("Expected specific tool choice");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_bedrock_auto_tool_choice() {
|
||||
let openai_request = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Help me with something".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_completion_tokens: Some(500),
|
||||
stop: None,
|
||||
stream: None,
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "help_tool".to_string(),
|
||||
description: Some("A helpful tool".to_string()),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Type(ToolChoiceType::Auto)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let bedrock_request: ConverseRequest = openai_request.try_into().unwrap();
|
||||
|
||||
assert!(bedrock_request.tool_config.is_some());
|
||||
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
|
||||
assert!(matches!(
|
||||
tool_config.tool_choice,
|
||||
Some(BedrockToolChoice::Auto { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_bedrock_multi_message_conversation() {
|
||||
let openai_request = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("Be concise".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text("Hi there! How can I help you?".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("What's 2+2?".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
temperature: Some(0.5),
|
||||
top_p: None,
|
||||
max_completion_tokens: Some(100),
|
||||
stop: None,
|
||||
stream: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let bedrock_request: ConverseRequest = openai_request.try_into().unwrap();
|
||||
|
||||
assert!(bedrock_request.messages.is_some());
|
||||
let messages = bedrock_request.messages.as_ref().unwrap();
|
||||
assert_eq!(messages.len(), 3); // System message is separate
|
||||
assert_eq!(messages[0].role, ConversationRole::User);
|
||||
assert_eq!(messages[1].role, ConversationRole::Assistant);
|
||||
assert_eq!(messages[2].role, ConversationRole::User);
|
||||
|
||||
// Check system prompt
|
||||
assert!(bedrock_request.system.is_some());
|
||||
if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] {
|
||||
assert_eq!(text, "Be concise");
|
||||
} else {
|
||||
panic!("Expected system text block");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_message_to_bedrock_conversion() {
|
||||
let openai_message = Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Test message".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
|
||||
let bedrock_message: BedrockMessage = openai_message.try_into().unwrap();
|
||||
|
||||
assert_eq!(bedrock_message.role, ConversationRole::User);
|
||||
assert_eq!(bedrock_message.content.len(), 1);
|
||||
|
||||
if let ContentBlock::Text { text } = &bedrock_message.content[0] {
|
||||
assert_eq!(text, "Test message");
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
}
|
||||
}
|
||||
4
crates/hermesllm/src/transforms/request/mod.rs
Normal file
4
crates/hermesllm/src/transforms/request/mod.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
//! Request transformation modules
|
||||
|
||||
pub mod from_anthropic;
|
||||
pub mod from_openai;
|
||||
3
crates/hermesllm/src/transforms/response/mod.rs
Normal file
3
crates/hermesllm/src/transforms/response/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
//! Response transformation modules
|
||||
pub mod to_anthropic;
|
||||
pub mod to_openai;
|
||||
1051
crates/hermesllm/src/transforms/response/to_anthropic.rs
Normal file
1051
crates/hermesllm/src/transforms/response/to_anthropic.rs
Normal file
File diff suppressed because it is too large
Load diff
1171
crates/hermesllm/src/transforms/response/to_openai.rs
Normal file
1171
crates/hermesllm/src/transforms/response/to_openai.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -23,6 +23,7 @@ thiserror = "1.0.64"
|
|||
derivative = "2.2.0"
|
||||
sha2 = "0.10.8"
|
||||
hermesllm = { version = "0.1.0", path = "../hermesllm" }
|
||||
bytes = "1.10"
|
||||
|
||||
[dev-dependencies]
|
||||
serial_test = "3.1.1"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -12,8 +13,8 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
|||
use crate::metrics::Metrics;
|
||||
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
|
||||
use common::consts::{
|
||||
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, RATELIMIT_SELECTOR_HEADER_KEY,
|
||||
REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::errors::ServerError;
|
||||
use common::llm_providers::LlmProviders;
|
||||
|
|
@ -21,9 +22,15 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent};
|
||||
use hermesllm::apis::sse::{SseEvent, SseStreamIter};
|
||||
use hermesllm::clients::endpoints::SupportedAPIs;
|
||||
use hermesllm::providers::response::{ProviderResponse, SseEvent, SseStreamIter};
|
||||
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
|
||||
use hermesllm::providers::response::ProviderResponse;
|
||||
use hermesllm::{
|
||||
DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType,
|
||||
ProviderStreamResponseType,
|
||||
};
|
||||
|
||||
pub struct StreamContext {
|
||||
metrics: Rc<Metrics>,
|
||||
|
|
@ -33,7 +40,7 @@ pub struct StreamContext {
|
|||
/// The API that is requested by the client (before compatibility mapping)
|
||||
client_api: Option<SupportedAPIs>,
|
||||
/// The API that should be used for the upstream provider (after compatibility mapping)
|
||||
resolved_api: Option<SupportedAPIs>,
|
||||
resolved_api: Option<SupportedUpstreamAPIs>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
request_id: Option<String>,
|
||||
|
|
@ -45,8 +52,8 @@ pub struct StreamContext {
|
|||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
user_message: Option<String>,
|
||||
/// Store upstream response status code to handle error responses gracefully
|
||||
upstream_status_code: Option<StatusCode>,
|
||||
binary_frame_decoder: Option<BedrockBinaryFrameDecoder<bytes::BytesMut>>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -75,6 +82,7 @@ impl StreamContext {
|
|||
request_body_sent_time: None,
|
||||
user_message: None,
|
||||
upstream_status_code: None,
|
||||
binary_frame_decoder: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -108,6 +116,7 @@ impl StreamContext {
|
|||
.model
|
||||
.as_ref()
|
||||
.unwrap_or(&"".to_string()),
|
||||
self.streaming_response,
|
||||
);
|
||||
if target_endpoint != request_path {
|
||||
self.set_http_request_header(":path", Some(&target_endpoint));
|
||||
|
|
@ -148,14 +157,19 @@ impl StreamContext {
|
|||
|
||||
// Set API-specific headers based on the resolved upstream API
|
||||
match self.resolved_api.as_ref() {
|
||||
Some(SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
|
||||
// Anthropic API requires x-api-key and anthropic-version headers
|
||||
// Remove any existing Authorization header since Anthropic doesn't use it
|
||||
self.remove_http_request_header("Authorization");
|
||||
self.set_http_request_header("x-api-key", Some(llm_provider_api_key_value));
|
||||
self.set_http_request_header("anthropic-version", Some("2023-06-01"));
|
||||
}
|
||||
Some(SupportedAPIs::OpenAIChatCompletions(_)) | None => {
|
||||
Some(
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverse(_)
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
)
|
||||
| None => {
|
||||
// OpenAI and default: use Authorization Bearer token
|
||||
// Remove any existing x-api-key header since OpenAI doesn't use it
|
||||
self.remove_http_request_header("x-api-key");
|
||||
|
|
@ -410,7 +424,16 @@ impl StreamContext {
|
|||
match self.client_api.as_ref() {
|
||||
Some(client_api) => {
|
||||
let client_api = client_api.clone(); // Clone to avoid borrowing issues
|
||||
let upstream_api = provider_id.compatible_api_for_client(&client_api);
|
||||
let upstream_api =
|
||||
provider_id.compatible_api_for_client(&client_api, self.streaming_response);
|
||||
|
||||
// Check if this is Bedrock binary stream
|
||||
if matches!(
|
||||
upstream_api,
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
|
||||
) {
|
||||
return self.handle_bedrock_binary_stream(body, &client_api, &upstream_api);
|
||||
}
|
||||
|
||||
// Parse body into SSE iterator using TryFrom
|
||||
let sse_iter: SseStreamIter<std::vec::IntoIter<String>> =
|
||||
|
|
@ -487,6 +510,127 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_bedrock_binary_stream(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
client_api: &SupportedAPIs,
|
||||
upstream_api: &SupportedUpstreamAPIs,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
// Initialize decoder if not present
|
||||
if self.binary_frame_decoder.is_none() {
|
||||
self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[]));
|
||||
}
|
||||
|
||||
// Add incoming bytes to buffer
|
||||
let decoder = self.binary_frame_decoder.as_mut().unwrap();
|
||||
decoder.buffer_mut().extend_from_slice(body);
|
||||
|
||||
let mut response_buffer = Vec::new();
|
||||
loop {
|
||||
let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame();
|
||||
match decoded_frame {
|
||||
Some(DecodedFrame::Complete(ref frame_ref)) => {
|
||||
let frame = DecodedFrame::Complete(frame_ref.clone());
|
||||
match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) {
|
||||
Ok(provider_response) => {
|
||||
self.record_ttft_if_needed();
|
||||
|
||||
// Handle ContentBlockStart and ContentBlockDelta events
|
||||
match &provider_response {
|
||||
ProviderStreamResponseType::MessagesStreamEvent(evt) => {
|
||||
match evt {
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index, ..
|
||||
} => {
|
||||
// Mark that we've seen ContentBlockStart for this index
|
||||
self.binary_frame_decoder
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.set_content_block_start_sent(*index as i32);
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}",
|
||||
self.request_identifier(),
|
||||
*index
|
||||
);
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta {
|
||||
index, ..
|
||||
} => {
|
||||
// Check if ContentBlockStart was sent for this index
|
||||
let needs_start = !self
|
||||
.binary_frame_decoder
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.has_content_block_start_been_sent(*index as i32);
|
||||
|
||||
if needs_start {
|
||||
// Emit empty ContentBlockStart before delta
|
||||
let content_block_start =
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index: *index,
|
||||
content_block: MessagesContentBlock::Text {
|
||||
text: String::new(),
|
||||
cache_control: None,
|
||||
},
|
||||
};
|
||||
let start_sse: String = content_block_start.into();
|
||||
response_buffer
|
||||
.extend_from_slice(start_sse.as_bytes());
|
||||
|
||||
// Mark that we've now sent it
|
||||
self.binary_frame_decoder
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.set_content_block_start_sent(*index as i32);
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}",
|
||||
self.request_identifier(),
|
||||
*index
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let sse_string: String = provider_response.into();
|
||||
response_buffer.extend_from_slice(sse_string.as_bytes());
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(DecodedFrame::Incomplete) => {
|
||||
// Incomplete frame - buffer retains partial data, wait for more bytes
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data",
|
||||
self.request_identifier()
|
||||
);
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
// Decode error
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_DECODE_ERROR",
|
||||
self.request_identifier()
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return accumulated complete frames (may be empty if all frames incomplete)
|
||||
Ok(response_buffer)
|
||||
}
|
||||
|
||||
fn handle_non_streaming_response(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
|
|
@ -578,6 +722,11 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.streaming_response = self
|
||||
.get_http_request_header(ARCH_IS_STREAMING_HEADER)
|
||||
.map(|val| val == "true")
|
||||
.unwrap_or(false);
|
||||
|
||||
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
None => false,
|
||||
|
|
@ -612,7 +761,17 @@ impl HttpContext for StreamContext {
|
|||
(self.client_api.as_ref(), self.llm_provider.as_ref())
|
||||
{
|
||||
let provider_id = provider.to_provider_id();
|
||||
self.resolved_api = Some(provider_id.compatible_api_for_client(api));
|
||||
self.resolved_api =
|
||||
Some(provider_id.compatible_api_for_client(api, self.streaming_response));
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] ROUTING_INFO: provider='{}' client_api={:?} resolved_api={:?} request_path='{}'",
|
||||
self.request_identifier(),
|
||||
provider.to_provider_id(),
|
||||
api,
|
||||
self.resolved_api,
|
||||
request_path
|
||||
);
|
||||
} else {
|
||||
self.resolved_api = None;
|
||||
}
|
||||
|
|
@ -697,7 +856,7 @@ impl HttpContext for StreamContext {
|
|||
//We need to deserialize the request body based on the resolved API
|
||||
let mut deserialized_client_request: ProviderRequestType = match self.client_api.as_ref() {
|
||||
Some(the_client_api) => {
|
||||
debug!(
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_RECEIVED: api={:?} body_size={}",
|
||||
self.request_identifier(),
|
||||
the_client_api,
|
||||
|
|
@ -795,7 +954,10 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
|
||||
// Use provider interface for streaming detection and setup
|
||||
self.streaming_response = deserialized_client_request.is_streaming();
|
||||
// If streaming_response is not already set from headers, get it from the parsed request
|
||||
if !self.streaming_response {
|
||||
self.streaming_response = deserialized_client_request.is_streaming();
|
||||
}
|
||||
|
||||
// Use provider interface for text extraction (after potential mutation)
|
||||
let input_tokens_str = deserialized_client_request.extract_messages_text();
|
||||
|
|
|
|||
|
|
@ -35,6 +35,10 @@ llm_providers:
|
|||
access_key: $AZURE_API_KEY
|
||||
base_url: https://katanemo.openai.azure.com
|
||||
|
||||
- model: amazon_bedrock/us.amazon.nova-premier-v1:0
|
||||
access_key: $AWS_BEARER_TOKEN_BEDROCK
|
||||
base_url: https://bedrock-runtime.us-west-2.amazonaws.com
|
||||
|
||||
# Ollama Models
|
||||
- model: ollama/llama3.1
|
||||
base_url: http://host.docker.internal:11434
|
||||
|
|
@ -71,3 +75,6 @@ model_aliases:
|
|||
|
||||
creative-model:
|
||||
target: claude-sonnet-4-20250514
|
||||
|
||||
coding-model:
|
||||
target: us.amazon.nova-premier-v1:0
|
||||
|
|
|
|||
|
|
@ -517,6 +517,36 @@ Azure OpenAI
|
|||
access_key: $AZURE_OPENAI_API_KEY
|
||||
base_url: https://your-resource.openai.azure.com
|
||||
|
||||
Amazon Bedrock
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
**Provider Prefix:** ``amazon_bedrock/``
|
||||
|
||||
**API Endpoint:** Arch automatically constructs the endpoint as:
|
||||
- Non-streaming: ``/model/{model-id}/converse``
|
||||
- Streaming: ``/model/{model-id}/converse-stream``
|
||||
|
||||
**Authentication:** AWS Bearer Token + Base URL - Get your API Keys from `AWS Bedrock Console <https://console.aws.amazon.com/bedrock/>`_ → Discover → API Keys.
|
||||
|
||||
**Supported Chat Models:** All Amazon Bedrock foundation models including Claude (Anthropic), Nova (Amazon), Llama (Meta), Mistral AI, and Cohere Command models.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
llm_providers:
|
||||
# Amazon Nova models
|
||||
- model: amazon_bedrock/us.amazon.nova-premier-v1:0
|
||||
access_key: $AWS_BEARER_TOKEN_BEDROCK
|
||||
base_url: https://bedrock-runtime.us-west-2.amazonaws.com
|
||||
default: true
|
||||
|
||||
- model: amazon_bedrock/us.amazon.nova-pro-v1:0
|
||||
access_key: $AWS_BEARER_TOKEN_BEDROCK
|
||||
base_url: https://bedrock-runtime.us-west-2.amazonaws.com
|
||||
|
||||
# Claude on Bedrock
|
||||
- model: amazon_bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0
|
||||
access_key: $AWS_BEARER_TOKEN_BEDROCK
|
||||
base_url: https://bedrock-runtime.us-west-2.amazonaws.com
|
||||
|
||||
Qwen (Alibaba)
|
||||
~~~~~~~~~~~~~~
|
||||
|
|
@ -540,8 +570,7 @@ Qwen (Alibaba)
|
|||
# Multiple deployments
|
||||
- model: qwen/qwen3-coder
|
||||
access_key: $DASHSCOPE_API_KEY
|
||||
base_url: "https://dashscope-intl.aliyuncs.com",
|
||||
|
||||
base_url: "https://dashscope-intl.aliyuncs.com"
|
||||
|
||||
Ollama
|
||||
~~~~~~
|
||||
|
|
|
|||
BIN
tests/e2e/response.hex
Normal file
BIN
tests/e2e/response.hex
Normal file
Binary file not shown.
BIN
tests/e2e/response_with_tools.hex
Normal file
BIN
tests/e2e/response_with_tools.hex
Normal file
Binary file not shown.
|
|
@ -403,3 +403,381 @@ def test_anthropic_thinking_mode_streaming():
|
|||
final_block_types = [blk.type for blk in final.content]
|
||||
assert "text" in final_block_types
|
||||
assert "thinking" in final_block_types
|
||||
|
||||
|
||||
def test_openai_client_with_coding_model_alias_and_tools():
|
||||
"""Test OpenAI client using 'coding-model' alias (maps to Bedrock) with coding question and tools"""
|
||||
logger.info("Testing OpenAI client with 'coding-model' alias -> Bedrock with tools")
|
||||
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(
|
||||
api_key="test-key",
|
||||
base_url=f"{base_url}/v1",
|
||||
)
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0
|
||||
max_tokens=1000,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?",
|
||||
}
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "run_python_code",
|
||||
"description": "Execute Python code and return the result",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Python code to execute",
|
||||
}
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
response_content = completion.choices[0].message.content
|
||||
tool_calls = completion.choices[0].message.tool_calls
|
||||
# Should get either text response or tool calls for coding assistance
|
||||
assert response_content is not None or (
|
||||
tool_calls is not None and len(tool_calls) > 0
|
||||
)
|
||||
|
||||
|
||||
def test_anthropic_client_with_coding_model_alias_and_tools():
|
||||
"""Test Anthropic client using 'coding-model' alias (maps to Bedrock) with coding question and tools"""
|
||||
logger.info(
|
||||
"Testing Anthropic client with 'coding-model' alias -> Bedrock with tools"
|
||||
)
|
||||
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = anthropic.Anthropic(api_key="test-key", base_url=base_url)
|
||||
|
||||
message = client.messages.create(
|
||||
model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0
|
||||
max_tokens=1000,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?",
|
||||
}
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"name": "run_python_code",
|
||||
"description": "Execute Python code and return the result",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Python code to execute",
|
||||
}
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice={"type": "auto"},
|
||||
)
|
||||
|
||||
text_content = "".join(b.text for b in message.content if b.type == "text")
|
||||
tool_use_blocks = [b for b in message.content if b.type == "tool_use"]
|
||||
|
||||
logger.info(f"Response from coding-model alias via Anthropic: {text_content}")
|
||||
logger.info(f"Tool use blocks: {len(tool_use_blocks)}")
|
||||
|
||||
# Should get either text response or tool use blocks for coding assistance
|
||||
assert text_content or len(tool_use_blocks) > 0
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=0) # Disable retries to see the actual failure
|
||||
def test_anthropic_client_with_coding_model_alias_and_tools_streaming():
|
||||
"""Test Anthropic client using 'coding-model' alias (maps to Bedrock) with coding question and tools - streaming"""
|
||||
logger.info(
|
||||
"Testing Anthropic client with 'coding-model' alias -> Bedrock with tools (streaming)"
|
||||
)
|
||||
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = anthropic.Anthropic(api_key="test-key", base_url=base_url)
|
||||
|
||||
text_chunks = []
|
||||
tool_use_blocks = []
|
||||
all_events = [] # Capture all events for debugging
|
||||
|
||||
try:
|
||||
with client.messages.stream(
|
||||
model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0
|
||||
max_tokens=1000,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?",
|
||||
}
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"name": "run_python_code",
|
||||
"description": "Execute Python code and return the result",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Python code to execute",
|
||||
}
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice={"type": "auto"},
|
||||
) as stream:
|
||||
for event in stream:
|
||||
# Extract index if available
|
||||
index = getattr(event, "index", None)
|
||||
|
||||
# Log and capture all events for debugging
|
||||
all_events.append(
|
||||
{"type": event.type, "index": index, "event": str(event)[:200]}
|
||||
)
|
||||
logger.info(f"Event #{len(all_events)}: {event.type} [index={index}]")
|
||||
|
||||
# Collect text deltas
|
||||
if event.type == "content_block_delta" and hasattr(event, "delta"):
|
||||
if event.delta.type == "text_delta":
|
||||
text_chunks.append(event.delta.text)
|
||||
|
||||
# Collect tool use blocks
|
||||
if event.type == "content_block_start" and hasattr(
|
||||
event, "content_block"
|
||||
):
|
||||
if event.content_block.type == "tool_use":
|
||||
tool_use_blocks.append(event.content_block)
|
||||
|
||||
final_message = stream.get_final_message()
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during streaming: {type(e).__name__}: {e}")
|
||||
logger.error(f"Events received before error: {len(all_events)}")
|
||||
logger.error(f"Text chunks collected: {len(text_chunks)}")
|
||||
logger.error(f"Tool use blocks collected: {len(tool_use_blocks)}")
|
||||
logger.error("\nLast 20 events before crash:")
|
||||
for evt in all_events[-20:]:
|
||||
logger.error(f" {evt['type']:30s} index={evt['index']}")
|
||||
raise
|
||||
|
||||
full_text = "".join(text_chunks)
|
||||
logger.info(f"Streaming response from coding-model with tools: {full_text}")
|
||||
logger.info(f"Total events received: {len(all_events)}")
|
||||
logger.info(
|
||||
f"Text chunks: {len(text_chunks)}, Tool use blocks: {len(tool_use_blocks)}"
|
||||
)
|
||||
|
||||
# Should get either text response or tool use blocks for coding assistance
|
||||
# Modified assertion to be more lenient and provide better error messages
|
||||
assert (
|
||||
full_text or len(tool_use_blocks) > 0
|
||||
), f"Expected text or tool use. Got text_len={len(full_text)}, tools={len(tool_use_blocks)}, events={len(all_events)}"
|
||||
|
||||
# Verify final message structure
|
||||
assert final_message is not None, "Final message should not be None"
|
||||
assert (
|
||||
final_message.content and len(final_message.content) > 0
|
||||
), f"Final message should have content. Got: {final_message.content if final_message else 'None'}"
|
||||
|
||||
|
||||
def test_anthropic_client_streaming_with_bedrock():
|
||||
"""Test Anthropic client using 'coding-model' alias (maps to Bedrock) with streaming"""
|
||||
logger.info(
|
||||
"Testing Anthropic client with 'coding-model' alias -> Bedrock (streaming)"
|
||||
)
|
||||
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = anthropic.Anthropic(api_key="test-key", base_url=base_url)
|
||||
|
||||
text_chunks = []
|
||||
|
||||
with client.messages.stream(
|
||||
model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0
|
||||
max_tokens=500,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Write a short 4-line sonnet about coding.",
|
||||
}
|
||||
],
|
||||
) as stream:
|
||||
for event in stream:
|
||||
# Collect text deltas
|
||||
if event.type == "content_block_delta" and hasattr(event, "delta"):
|
||||
if event.delta.type == "text_delta":
|
||||
text_chunks.append(event.delta.text)
|
||||
|
||||
final_message = stream.get_final_message()
|
||||
|
||||
full_text = "".join(text_chunks)
|
||||
logger.info(f"Response: {full_text}")
|
||||
|
||||
# Should get a text response
|
||||
assert len(full_text) > 0, "Expected text response from streaming"
|
||||
|
||||
# Verify final message structure
|
||||
assert final_message is not None
|
||||
assert final_message.content and len(final_message.content) > 0
|
||||
|
||||
|
||||
def test_openai_client_streaming_with_bedrock():
|
||||
"""Test OpenAI client using 'coding-model' alias (maps to Bedrock) with streaming"""
|
||||
logger.info(
|
||||
"Testing OpenAI client with 'coding-model' alias -> Bedrock (streaming)"
|
||||
)
|
||||
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(
|
||||
api_key="test-key",
|
||||
base_url=f"{base_url}/v1",
|
||||
)
|
||||
|
||||
stream = client.chat.completions.create(
|
||||
model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0
|
||||
max_tokens=500,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Write a short 4-line sonnet about coding.",
|
||||
}
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
content_chunks = []
|
||||
for chunk in stream:
|
||||
if chunk.choices and len(chunk.choices) > 0:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.content:
|
||||
content_chunks.append(delta.content)
|
||||
|
||||
full_content = "".join(content_chunks)
|
||||
logger.info(f"Streaming response from coding-model: {full_content}")
|
||||
|
||||
# Should get a text response
|
||||
assert len(full_content) > 0, "Expected text response from streaming"
|
||||
|
||||
|
||||
def test_openai_client_streaming_with_bedrock_and_tools():
|
||||
"""Test OpenAI client using 'coding-model' alias (maps to Bedrock) with streaming and tools"""
|
||||
logger.info(
|
||||
"Testing OpenAI client with 'coding-model' alias -> Bedrock with tools (streaming)"
|
||||
)
|
||||
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(
|
||||
api_key="test-key",
|
||||
base_url=f"{base_url}/v1",
|
||||
)
|
||||
|
||||
stream = client.chat.completions.create(
|
||||
model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0
|
||||
max_tokens=1000,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?. You should use the tool to run the code.",
|
||||
}
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "run_python_code",
|
||||
"description": "Execute Python code and return the result",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Python code to execute",
|
||||
}
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
content_chunks = []
|
||||
tool_calls = []
|
||||
chunk_count = 0
|
||||
|
||||
for chunk in stream:
|
||||
chunk_count += 1
|
||||
if chunk.choices and len(chunk.choices) > 0:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# Log what we see in each chunk
|
||||
has_content = delta.content is not None
|
||||
has_tool_calls = delta.tool_calls is not None
|
||||
|
||||
if (
|
||||
chunk_count % 50 == 0 or has_tool_calls
|
||||
): # Log every 50th chunk or any chunk with tool calls
|
||||
logger.info(
|
||||
f"Chunk {chunk_count}: content={has_content}, tool_calls={has_tool_calls}"
|
||||
)
|
||||
if has_tool_calls:
|
||||
logger.info(f" Tool calls in chunk: {delta.tool_calls}")
|
||||
|
||||
# Collect text content
|
||||
if delta.content:
|
||||
content_chunks.append(delta.content)
|
||||
|
||||
# Collect tool calls
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
# Extend or create tool call entries
|
||||
while len(tool_calls) <= tool_call.index:
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
)
|
||||
|
||||
if tool_call.id:
|
||||
tool_calls[tool_call.index]["id"] = tool_call.id
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
tool_calls[tool_call.index]["function"][
|
||||
"name"
|
||||
] = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
tool_calls[tool_call.index]["function"][
|
||||
"arguments"
|
||||
] += tool_call.function.arguments
|
||||
|
||||
full_content = "".join(content_chunks)
|
||||
logger.info(f"Streaming response from coding-model with tools: {full_content}")
|
||||
logger.info(f"Tool calls collected: {len(tool_calls)}")
|
||||
|
||||
if tool_calls:
|
||||
for i, tc in enumerate(tool_calls):
|
||||
logger.info(f" Tool call {i}: {tc['function']['name']}")
|
||||
|
||||
# Should get either text response or tool calls for coding assistance
|
||||
assert (
|
||||
full_content or len(tool_calls) > 0
|
||||
), f"Expected text or tool calls. Got text_len={len(full_content)}, tools={len(tool_calls)}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue