plano/crates/hermesllm/src/providers/response.rs
2026-04-17 14:03:47 -07:00

453 lines
18 KiB
Rust

use crate::apis::amazon_bedrock::ConverseResponse;
use crate::apis::anthropic::MessagesResponse;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::openai_responses::ResponsesAPIResponse;
use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use crate::providers::id::ProviderId;
use serde::Serialize;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
/// A parsed provider response, in the API shape the *client* requested.
///
/// Values are built through the `TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)>`
/// impl in this module. That impl parses the upstream body and, when the upstream API differs
/// from the client API, converts it into the client's format.
///
/// `#[serde(untagged)]` means serializing a value emits the inner response's JSON directly,
/// without a wrapping variant tag.
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum ProviderResponseType {
    /// An OpenAI Chat Completions response body.
    ChatCompletionsResponse(ChatCompletionsResponse),
    /// An Anthropic Messages response body.
    MessagesResponse(MessagesResponse),
    /// An OpenAI Responses API response body. It is boxed, presumably to keep the enum's size
    /// down because this payload is larger than the others (NOTE(review): confirm).
    ResponsesAPIResponse(Box<ResponsesAPIResponse>),
}
/// Token accounting exposed by a provider's usage payload.
///
/// Implementors must report the three core counters. The cache- and reasoning-related
/// counters are optional extensions; they default to `None` for providers that do not
/// report them.
pub trait TokenUsage {
    /// Prompt (input-side) token count.
    fn prompt_tokens(&self) -> usize;
    /// Completion (output-side) token count.
    fn completion_tokens(&self) -> usize;
    /// Total token count as reported by the implementor.
    fn total_tokens(&self) -> usize;

    /// Tokens served from a prompt-cache read. Sources: OpenAI
    /// `prompt_tokens_details.cached_tokens`, Anthropic `cache_read_input_tokens`,
    /// Google `cached_content_token_count`. Defaults to `None`.
    fn cached_input_tokens(&self) -> Option<usize> {
        None
    }

    /// Tokens spent writing a prompt-cache entry (Anthropic `cache_creation_input_tokens`).
    /// Defaults to `None`.
    fn cache_creation_tokens(&self) -> Option<usize> {
        None
    }

    /// Reasoning tokens emitted by reasoning models. Sources: OpenAI
    /// `completion_tokens_details.reasoning_tokens`, Google `thoughts_token_count`.
    /// Defaults to `None`.
    fn reasoning_tokens(&self) -> Option<usize> {
        None
    }
}
/// Rich usage breakdown extracted from a provider response.
///
/// Produced by `ProviderResponse::extract_usage_details`. Each optional field mirrors the
/// matching `TokenUsage` method and is `None` when that method reports nothing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UsageDetails {
    /// Value of `TokenUsage::prompt_tokens`.
    pub prompt_tokens: usize,
    /// Value of `TokenUsage::completion_tokens`.
    pub completion_tokens: usize,
    /// Value of `TokenUsage::total_tokens`.
    pub total_tokens: usize,
    /// Tokens read from a prompt cache, if reported (`TokenUsage::cached_input_tokens`).
    pub cached_input_tokens: Option<usize>,
    /// Tokens written to a prompt cache, if reported (`TokenUsage::cache_creation_tokens`).
    pub cache_creation_tokens: Option<usize>,
    /// Reasoning tokens, if reported (`TokenUsage::reasoning_tokens`).
    pub reasoning_tokens: Option<usize>,
}
/// A provider response that may carry token-usage accounting.
///
/// Implementors only have to provide `usage`. Both extraction helpers are derived from it
/// and return `None` whenever `usage` does.
pub trait ProviderResponse: Send + Sync {
    /// Borrows the response's usage payload as a `TokenUsage` trait object, if present.
    fn usage(&self) -> Option<&dyn TokenUsage>;

    /// Returns `(prompt, completion, total)` token counts for metrics.
    fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
        let usage = self.usage()?;
        let prompt = usage.prompt_tokens();
        let completion = usage.completion_tokens();
        let total = usage.total_tokens();
        Some((prompt, completion, total))
    }

    /// Returns the full usage breakdown. This includes the optional cache-read, cache-write
    /// and reasoning counters as reported by the `TokenUsage` view.
    fn extract_usage_details(&self) -> Option<UsageDetails> {
        let usage = self.usage()?;
        let details = UsageDetails {
            prompt_tokens: usage.prompt_tokens(),
            completion_tokens: usage.completion_tokens(),
            total_tokens: usage.total_tokens(),
            cached_input_tokens: usage.cached_input_tokens(),
            cache_creation_tokens: usage.cache_creation_tokens(),
            reasoning_tokens: usage.reasoning_tokens(),
        };
        Some(details)
    }
}
impl ProviderResponse for ProviderResponseType {
fn usage(&self) -> Option<&dyn TokenUsage> {
match self {
ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(),
ProviderResponseType::MessagesResponse(resp) => resp.usage(),
ProviderResponseType::ResponsesAPIResponse(resp) => {
resp.usage.as_ref().map(|u| u as &dyn TokenUsage)
}
}
}
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
match self {
ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| {
(
u.input_tokens as usize,
u.output_tokens as usize,
u.total_tokens as usize,
)
}),
}
}
}
// --- Response transformation logic for client API compatibility ---
impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderResponseType {
type Error = std::io::Error;
fn try_from(
(bytes, client_api, provider_id): (&[u8], &SupportedAPIsFromClient, &ProviderId),
) -> Result<Self, Self::Error> {
let upstream_api = provider_id.compatible_api_for_client(client_api, false);
match (&upstream_api, client_api) {
(
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIsFromClient::OpenAIChatCompletions(_),
) => {
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
}
(
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
) => {
let resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::MessagesResponse(resp))
}
(
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIsFromClient::OpenAIChatCompletions(_),
) => {
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to OpenAI ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse =
anthropic_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
}
(
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
) => {
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to Anthropic Messages format using the transformer
let messages_resp: MessagesResponse = openai_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::MessagesResponse(messages_resp))
}
// Amazon Bedrock transformations
(
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
SupportedAPIsFromClient::OpenAIChatCompletions(_),
) => {
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to OpenAI ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
}
(
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
) => {
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to Anthropic Messages format using the transformer
let messages_resp: MessagesResponse = bedrock_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::MessagesResponse(messages_resp))
}
(
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => {
let resp: ResponsesAPIResponse = ResponsesAPIResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(resp)))
}
(
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => {
let chat_completions_response: ChatCompletionsResponse =
ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to ResponsesAPI format using the transformer
let responses_resp: ResponsesAPIResponse =
chat_completions_response.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
responses_resp,
)))
}
(
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => {
//Chain transform: Anthropic Messages -> OpenAI ChatCompletions -> ResponsesAPI
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse =
anthropic_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
response_api,
)))
}
(
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => {
// Chain transform: Bedrock Converse -> ChatCompletions -> ResponsesAPI
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to ChatCompletions format
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Bedrock to ChatCompletions transformation error: {}", e),
)
})?;
// Transform to ResponsesAPI format
let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"ChatCompletions to ResponsesAPI transformation error: {}",
e
),
)
})?;
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
response_api,
)))
}
_ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Unsupported API combination for response transformation",
)),
}
}
}
/// Error raised while handling a provider response.
///
/// `message` is a human-readable description. `source` optionally carries the underlying
/// cause, which is exposed through `Error::source`.
#[derive(Debug)]
pub struct ProviderResponseError {
    pub message: String,
    pub source: Option<Box<dyn Error + Send + Sync>>,
}

impl fmt::Display for ProviderResponseError {
    /// Renders as `Provider response error: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Provider response error: ")?;
        f.write_str(&self.message)
    }
}

impl Error for ProviderResponseError {
    /// Returns the wrapped cause, if any, with its `Send + Sync` bounds erased.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        self.source
            .as_deref()
            .map(|cause| cause as &(dyn Error + 'static))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::apis::anthropic::AnthropicApi;
    use crate::apis::openai::OpenAIApi;
    use crate::clients::endpoints::SupportedAPIsFromClient;
    use crate::providers::id::ProviderId;
    use serde_json::json;

    /// Serializes `body` and runs it through `ProviderResponseType::try_from`.
    ///
    /// On failure this panics with the underlying conversion error, instead of the bare
    /// `is_ok()` assertion failure that hides why parsing failed.
    fn parse_response(
        body: serde_json::Value,
        client_api: &SupportedAPIsFromClient,
        provider_id: &ProviderId,
    ) -> ProviderResponseType {
        let bytes = serde_json::to_vec(&body).expect("test fixture must serialize");
        ProviderResponseType::try_from((bytes.as_slice(), client_api, provider_id))
            .unwrap_or_else(|e| panic!("response conversion failed: {}", e))
    }

    #[test]
    fn test_openai_response_from_bytes() {
        let resp = json!({
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [
                {
                    "index": 0,
                    "message": { "role": "assistant", "content": "Hello!" },
                    "finish_reason": "stop"
                }
            ],
            "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 },
            "system_fingerprint": null
        });
        let parsed = parse_response(
            resp,
            &SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
            &ProviderId::OpenAI,
        );
        // Usage must be reachable through the enum wrapper, not only via the inner type.
        assert_eq!(parsed.extract_usage_counts(), Some((5, 7, 12)));
        match parsed {
            ProviderResponseType::ChatCompletionsResponse(r) => {
                assert_eq!(r.model, "gpt-4");
                assert_eq!(r.choices.len(), 1);
            }
            _ => panic!("Expected ChatCompletionsResponse variant"),
        }
    }

    #[test]
    fn test_anthropic_response_from_bytes() {
        let resp = json!({
            "id": "msg_01ABC123",
            "type": "message",
            "role": "assistant",
            "content": [
                { "type": "text", "text": "Hello! How can I help you today?" }
            ],
            "model": "claude-3-sonnet-20240229",
            "stop_reason": "end_turn",
            "usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
        });
        let parsed = parse_response(
            resp,
            &SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages),
            &ProviderId::Anthropic,
        );
        match parsed {
            ProviderResponseType::MessagesResponse(r) => {
                assert_eq!(r.model, "claude-3-sonnet-20240229");
                assert_eq!(r.content.len(), 1);
            }
            _ => panic!("Expected MessagesResponse variant"),
        }
    }

    #[test]
    fn test_anthropic_response_from_bytes_with_openai_provider() {
        // The OpenAI provider returns an OpenAI response, but the client expects Anthropic
        // format. Upstream API = OpenAI, Client API = Anthropic: parse OpenAI, convert to
        // Anthropic.
        let resp = json!({
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [
                {
                    "index": 0,
                    "message": { "role": "assistant", "content": "Hello! How can I help you today?" },
                    "finish_reason": "stop"
                }
            ],
            "usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 }
        });
        let parsed = parse_response(
            resp,
            &SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages),
            &ProviderId::OpenAI,
        );
        match parsed {
            ProviderResponseType::MessagesResponse(r) => {
                assert_eq!(r.model, "gpt-4");
                assert_eq!(r.usage.input_tokens, 10);
                assert_eq!(r.usage.output_tokens, 25);
            }
            _ => panic!("Expected MessagesResponse variant"),
        }
    }

    #[test]
    fn test_openai_response_from_bytes_with_claude_provider() {
        // The Claude provider's OpenAI-compatible API returns an OpenAI-format response.
        // Client API = OpenAI, Provider = Anthropic: Anthropic returns OpenAI format via
        // that compatible API.
        let resp = json!({
            "id": "chatcmpl-01ABC123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "claude-3-sonnet-20240229",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello! How can I help you today?"
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 25,
                "total_tokens": 35
            }
        });
        let parsed = parse_response(
            resp,
            &SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
            &ProviderId::Anthropic,
        );
        match parsed {
            ProviderResponseType::ChatCompletionsResponse(r) => {
                assert_eq!(r.model, "claude-3-sonnet-20240229");
                assert_eq!(r.usage.prompt_tokens, 10);
                assert_eq!(r.usage.completion_tokens, 25);
            }
            _ => panic!("Expected ChatCompletionsResponse variant"),
        }
    }
}