add support for v1/messages and transformations (#558)

* pushing draft PR

* transformations are working. Now need to add some tests next

* updated tests and added necessary response transformations for Anthropic's message response object

* fixed bugs for integration tests

* fixed doc tests

* fixed serialization issues with enums on response

* adding some debug logs to help

* fixed issues with non-streaming responses

* updated the stream_context to update response bytes

* the serialized bytes length must be set in the response side

* fixed the debug statement that was causing the integration tests for wasm to fail

* fixing json parsing errors

* intentionally removing the headers

* making sure that we convert the raw bytes to the correct provider type upstream

* fixing non-streaming responses to transform correctly

* /v1/messages works with transformations to and from /v1/chat/completions

* updating the CLI and demos to support anthropic vs. claude

* adding the anthropic key to the preference based routing tests

* fixed test cases and added more structured logs

* fixed integration tests and cleaned up logs

* added python client tests for anthropic and openai

* cleaned up logs and fixed issue with connectivity for llm gateway in weather forecast demo

* fixing the tests. python dependency order was broken

* updated the openAI client to fix demos

* removed the raw response debug statement

* fixed the dup cloning issue and cleaned up the ProviderRequestType enum and traits

* fixing logs

* moved away from string literals to consts

* fixed streaming from Anthropic Client to OpenAI

* removed debug statement that would likely trip up integration tests

* fixed integration tests for llm_gateway

* cleaned up test cases and removed unnecessary crates

* fixing comments from PR

* fixed bug whereby we were sending an OpenAIChatCompletions request object to llm_gateway even though the request may have been AnthropicMessages

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-9.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-10.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-41.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-136.local>
This commit is contained in:
Salman Paracha 2025-09-10 07:40:30 -07:00 committed by GitHub
parent bb71d041a0
commit fb0581fd39
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 2842 additions and 919 deletions

View file

@ -1,9 +1,14 @@
use crate::providers::response::TokenUsage;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
use crate::clients::transformer::ExtractText;
use crate::{MESSAGES_PATH};
// Enum for all supported Anthropic APIs
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@ -17,13 +22,13 @@ pub enum AnthropicApi {
impl ApiDefinition for AnthropicApi {
fn endpoint(&self) -> &'static str {
match self {
AnthropicApi::Messages => "/v1/messages",
AnthropicApi::Messages => MESSAGES_PATH,
}
}
fn from_endpoint(endpoint: &str) -> Option<Self> {
match endpoint {
"/v1/messages" => Some(AnthropicApi::Messages),
MESSAGES_PATH => Some(AnthropicApi::Messages),
_ => None,
}
}
@ -186,6 +191,19 @@ pub enum MessagesContentBlock {
},
}
impl ExtractText for Vec<MessagesContentBlock> {
    /// Concatenates the text of every `Text` block, one per line; non-text
    /// blocks (images, tool use, etc.) contribute nothing.
    fn extract_text(&self) -> String {
        let mut lines: Vec<&str> = Vec::with_capacity(self.len());
        for block in self {
            if let MessagesContentBlock::Text { text } = block {
                lines.push(text.as_str());
            }
        }
        lines.join("\n")
    }
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum MessagesImageSource {
@ -220,6 +238,15 @@ pub enum MessagesMessageContent {
Blocks(Vec<MessagesContentBlock>),
}
impl ExtractText for MessagesMessageContent {
    /// Plain-text view of a message body: a bare string is cloned, while a
    /// list of content blocks delegates to the blocks' own extractor.
    fn extract_text(&self) -> String {
        match self {
            Self::Blocks(parts) => parts.extract_text(),
            Self::Single(text) => text.clone(),
        }
    }
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum MessagesSystemPrompt {
@ -369,6 +396,121 @@ impl MessagesRequest {
}
}
// Deserialize a raw JSON request body straight into a `MessagesRequest`.
impl TryFrom<&[u8]> for MessagesRequest {
    type Error = serde_json::Error;
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        // Malformed JSON surfaces as the underlying serde error unchanged.
        serde_json::from_slice(bytes)
    }
}
/// Token accounting for a non-streaming Anthropic Messages response.
impl TokenUsage for MessagesResponse {
    /// Tokens generated by the model (Anthropic's `output_tokens`).
    fn completion_tokens(&self) -> usize {
        self.usage.output_tokens as usize
    }
    /// Tokens consumed by the prompt (Anthropic's `input_tokens`).
    fn prompt_tokens(&self) -> usize {
        self.usage.input_tokens as usize
    }
    /// Total tokens. Each count is widened to `usize` *before* adding so the
    /// sum cannot wrap in the narrower on-the-wire integer type.
    fn total_tokens(&self) -> usize {
        self.usage.input_tokens as usize + self.usage.output_tokens as usize
    }
}
/// Usage reporting for a non-streaming Anthropic Messages response.
impl ProviderResponse for MessagesResponse {
    fn usage(&self) -> Option<&dyn TokenUsage> {
        Some(self)
    }
    /// Returns `(prompt, completion, total)` token counts. Delegates to the
    /// `TokenUsage` accessors so the widening arithmetic lives in one place
    /// instead of being duplicated here (the original re-computed the total
    /// with a narrow add that could wrap before the cast to `usize`).
    fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
        Some((
            self.prompt_tokens(),
            self.completion_tokens(),
            self.total_tokens(),
        ))
    }
}
/// Routing-facing view of an Anthropic Messages request.
impl ProviderRequest for MessagesRequest {
    fn model(&self) -> &str {
        &self.model
    }

    fn set_model(&mut self, model: String) {
        self.model = model;
    }

    /// A missing `stream` field is treated as non-streaming, matching
    /// Anthropic's API default.
    fn is_streaming(&self) -> bool {
        self.stream.unwrap_or(false)
    }

    /// Flattens the system prompt (if any) plus the text of every message
    /// into one space-separated string, e.g. for prompt classification.
    fn extract_messages_text(&self) -> String {
        let mut text_parts = Vec::new();

        // Include system prompt if present.
        if let Some(system) = &self.system {
            match system {
                MessagesSystemPrompt::Single(s) => text_parts.push(s.clone()),
                MessagesSystemPrompt::Blocks(blocks) => {
                    for block in blocks {
                        if let MessagesContentBlock::Text { text } = block {
                            text_parts.push(text.clone());
                        }
                    }
                }
            }
        }

        // Then the text of every message, in conversation order.
        for message in &self.messages {
            match &message.content {
                MessagesMessageContent::Single(text) => text_parts.push(text.clone()),
                MessagesMessageContent::Blocks(blocks) => {
                    for block in blocks {
                        if let MessagesContentBlock::Text { text } = block {
                            text_parts.push(text.clone());
                        }
                    }
                }
            }
        }

        text_parts.join(" ")
    }

    /// Returns the text of the most recent user message that carries any,
    /// scanning newest-to-oldest. For block content only the FIRST text
    /// block is returned; a user message with no text blocks is skipped in
    /// favor of an older one (same semantics as the loop it replaces).
    fn get_recent_user_message(&self) -> Option<String> {
        self.messages
            .iter()
            .rev()
            .filter(|m| m.role == MessagesRole::User)
            .find_map(|m| match &m.content {
                MessagesMessageContent::Single(text) => Some(text.clone()),
                MessagesMessageContent::Blocks(blocks) => blocks.iter().find_map(|b| {
                    if let MessagesContentBlock::Text { text } = b {
                        Some(text.clone())
                    } else {
                        None
                    }
                }),
            })
    }

    /// Serializes the request back to JSON bytes, wrapping serde failures
    /// in the provider-level error type.
    fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
        serde_json::to_vec(self).map_err(|e| ProviderRequestError {
            message: format!("Failed to serialize MessagesRequest: {}", e),
            source: Some(Box::new(e)),
        })
    }

    fn metadata(&self) -> &Option<HashMap<String, Value>> {
        // Trailing expression instead of an explicit `return`.
        &self.metadata
    }

    /// Removes `key` from the request metadata; returns whether it was present.
    fn remove_metadata_key(&mut self, key: &str) -> bool {
        self.metadata
            .as_mut()
            .map_or(false, |m| m.remove(key).is_some())
    }
}
impl MessagesResponse {
pub fn api_type() -> AnthropicApi {
AnthropicApi::Messages
@ -381,6 +523,54 @@ impl MessagesStreamEvent {
}
}
impl MessagesRole {
pub fn as_str(&self) -> &'static str {
match self {
MessagesRole::User => "user",
MessagesRole::Assistant => "assistant",
}
}
}
// Expose MessagesStreamEvent as a provider-agnostic stream item.
impl ProviderStreamResponse for MessagesStreamEvent {
    /// Incremental text, present only on `content_block_delta` events that
    /// carry a `text_delta` payload.
    fn content_delta(&self) -> Option<&str> {
        if let MessagesStreamEvent::ContentBlockDelta {
            delta: MessagesContentDelta::TextDelta { text },
            ..
        } = self
        {
            Some(text)
        } else {
            None
        }
    }

    /// The stream is complete once `message_stop` arrives.
    fn is_final(&self) -> bool {
        matches!(self, MessagesStreamEvent::MessageStop)
    }

    /// The role is only carried by the initial `message_start` event.
    fn role(&self) -> Option<&str> {
        if let MessagesStreamEvent::MessageStart { message } = self {
            Some(message.role.as_str())
        } else {
            None
        }
    }

    /// SSE event name for this variant (always known for Anthropic events).
    fn event_type(&self) -> Option<&str> {
        let name = match self {
            MessagesStreamEvent::MessageStart { .. } => "message_start",
            MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
            MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
            MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
            MessagesStreamEvent::MessageDelta { .. } => "message_delta",
            MessagesStreamEvent::MessageStop => "message_stop",
            MessagesStreamEvent::Ping => "ping",
        };
        Some(name)
    }
}
#[cfg(test)]
mod tests {
use super::*;
@ -878,13 +1068,13 @@ mod tests {
let api = AnthropicApi::Messages;
// Test trait methods
assert_eq!(api.endpoint(), "/v1/messages");
assert_eq!(api.endpoint(), MESSAGES_PATH);
assert!(api.supports_streaming());
assert!(api.supports_tools());
assert!(api.supports_vision());
// Test from_endpoint trait method
let found_api = AnthropicApi::from_endpoint("/v1/messages");
let found_api = AnthropicApi::from_endpoint(MESSAGES_PATH);
assert_eq!(found_api, Some(AnthropicApi::Messages));
let not_found = AnthropicApi::from_endpoint("/v1/unknown");

View file

@ -1,110 +1,9 @@
pub mod anthropic;
pub mod openai;
// Re-export all types for convenience
pub use anthropic::*;
pub use openai::*;
/// Common trait that all API definitions must implement
///
/// This trait ensures consistency across different AI provider API definitions
/// and makes it easy to add new providers like Gemini, Claude, etc.
///
/// Note: This is different from the `ApiProvider` enum in `clients::endpoints`
/// which represents provider identification, while this trait defines API capabilities.
///
/// # Benefits
///
/// - **Consistency**: All API providers implement the same interface
/// - **Extensibility**: Easy to add new providers without breaking existing code
/// - **Type Safety**: Compile-time guarantees that all providers implement required methods
/// - **Discoverability**: Clear documentation of what capabilities each API supports
///
/// # Example implementation for a new provider:
///
/// ```rust,ignore
/// use serde::{Deserialize, Serialize};
/// use super::ApiDefinition;
///
/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
/// pub enum GeminiApi {
/// GenerateContent,
/// ChatCompletions,
/// }
///
/// impl GeminiApi {
/// pub fn endpoint(&self) -> &'static str {
/// match self {
/// GeminiApi::GenerateContent => "/v1/models/gemini-pro:generateContent",
/// GeminiApi::ChatCompletions => "/v1/models/gemini-pro:chat",
/// }
/// }
///
/// pub fn from_endpoint(endpoint: &str) -> Option<Self> {
/// match endpoint {
/// "/v1/models/gemini-pro:generateContent" => Some(GeminiApi::GenerateContent),
/// "/v1/models/gemini-pro:chat" => Some(GeminiApi::ChatCompletions),
/// _ => None,
/// }
/// }
///
/// pub fn supports_streaming(&self) -> bool {
/// match self {
/// GeminiApi::GenerateContent => true,
/// GeminiApi::ChatCompletions => true,
/// }
/// }
///
/// pub fn supports_tools(&self) -> bool {
/// match self {
/// GeminiApi::GenerateContent => true,
/// GeminiApi::ChatCompletions => false,
/// }
/// }
///
/// pub fn supports_vision(&self) -> bool {
/// match self {
/// GeminiApi::GenerateContent => true,
/// GeminiApi::ChatCompletions => false,
/// }
/// }
/// }
///
/// impl ApiDefinition for GeminiApi {
/// fn endpoint(&self) -> &'static str {
/// self.endpoint()
/// }
///
/// fn from_endpoint(endpoint: &str) -> Option<Self> {
/// Self::from_endpoint(endpoint)
/// }
///
/// fn supports_streaming(&self) -> bool {
/// self.supports_streaming()
/// }
///
/// fn supports_tools(&self) -> bool {
/// self.supports_tools()
/// }
///
/// fn supports_vision(&self) -> bool {
/// self.supports_vision()
/// }
/// }
///
/// // Now you can use generic code that works with any API:
/// fn print_api_info<T: ApiDefinition>(api: &T) {
/// println!("Endpoint: {}", api.endpoint());
/// println!("Supports streaming: {}", api.supports_streaming());
/// println!("Supports tools: {}", api.supports_tools());
/// println!("Supports vision: {}", api.supports_vision());
/// }
///
/// // Works with both OpenAI and Anthropic (and future Gemini)
/// print_api_info(&OpenAIApi::ChatCompletions);
/// print_api_info(&AnthropicApi::Messages);
/// print_api_info(&GeminiApi::GenerateContent);
/// ```
pub trait ApiDefinition {
/// Returns the endpoint path for this API
fn endpoint(&self) -> &'static str;
@ -132,6 +31,7 @@ pub trait ApiDefinition {
#[cfg(test)]
mod tests {
use super::*;
use crate::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH};
#[test]
fn test_generic_api_functionality() {
@ -150,8 +50,8 @@ mod tests {
fn test_api_detection_from_endpoints() {
// Test that we can detect APIs from endpoints using the trait
let endpoints = vec![
"/v1/chat/completions",
"/v1/messages",
CHAT_COMPLETIONS_PATH,
MESSAGES_PATH,
"/v1/unknown"
];

View file

@ -5,11 +5,11 @@ use std::collections::HashMap;
use std::fmt::Display;
use thiserror::Error;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
use super::ApiDefinition;
use crate::clients::transformer::{ExtractText};
use crate::{CHAT_COMPLETIONS_PATH};
// ============================================================================
// OPENAI API ENUMERATION
@ -28,13 +28,13 @@ pub enum OpenAIApi {
impl ApiDefinition for OpenAIApi {
fn endpoint(&self) -> &'static str {
match self {
OpenAIApi::ChatCompletions => "/v1/chat/completions",
OpenAIApi::ChatCompletions => CHAT_COMPLETIONS_PATH,
}
}
fn from_endpoint(endpoint: &str) -> Option<Self> {
match endpoint {
"/v1/chat/completions" => Some(OpenAIApi::ChatCompletions),
CHAT_COMPLETIONS_PATH => Some(OpenAIApi::ChatCompletions),
_ => None,
}
}
@ -81,7 +81,7 @@ pub struct ChatCompletionsRequest {
// Maximum tokens in the response has been deprecated, but we keep it for compatibility
pub max_tokens: Option<u32>,
pub modalities: Option<Vec<String>>,
pub metadata: Option<HashMap<String, String>>,
pub metadata: Option<HashMap<String, Value>>,
pub n: Option<u32>,
pub presence_penalty: Option<f32>,
pub parallel_tool_calls: Option<bool>,
@ -174,6 +174,28 @@ pub enum MessageContent {
Parts(Vec<ContentPart>),
}
// Content Extraction
impl ExtractText for MessageContent {
    /// Plain-text view of OpenAI message content: bare strings are cloned,
    /// multi-part content delegates to the parts' extractor.
    fn extract_text(&self) -> String {
        match self {
            Self::Parts(parts) => parts.extract_text(),
            Self::Text(text) => text.clone(),
        }
    }
}
impl ExtractText for Vec<ContentPart> {
    /// Joins the text of every `Text` part with newlines; non-text parts
    /// (e.g. images) contribute nothing.
    fn extract_text(&self) -> String {
        let mut pieces: Vec<&str> = Vec::with_capacity(self.len());
        for part in self {
            if let ContentPart::Text { text } = part {
                pieces.push(text.as_str());
            }
        }
        pieces.join("\n")
    }
}
impl Display for MessageContent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@ -328,6 +350,7 @@ pub struct ChatCompletionsResponse {
pub choices: Vec<Choice>,
pub usage: Usage,
pub system_fingerprint: Option<String>,
pub service_tier: Option<String>,
}
/// Finish reason for completion
@ -576,6 +599,18 @@ impl ProviderRequest for ChatCompletionsRequest {
source: Some(Box::new(e)),
})
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata;
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
if let Some(ref mut metadata) = self.metadata {
metadata.remove(key).is_some()
} else {
false
}
}
}
/// Implementation of ProviderResponse for ChatCompletionsResponse
@ -593,68 +628,6 @@ impl ProviderResponse for ChatCompletionsResponse {
}
}
// ============================================================================
// OPENAI SSE STREAMING ITERATOR
// ============================================================================
/// OpenAI-specific SSE streaming iterator.
///
/// Handles OpenAI's SSE framing (`data: ` lines, the `[DONE]` sentinel) and
/// parses each payload into a `ChatCompletionsStreamResponse`.
pub struct OpenAISseIter<I>
where
    I: Iterator,
    I::Item: AsRef<str>,
{
    // The underlying line-oriented SSE stream being consumed.
    sse_stream: SseStreamIter<I>,
}
impl<I> OpenAISseIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
pub fn new(sse_stream: SseStreamIter<I>) -> Self {
Self { sse_stream }
}
}
impl<I> Iterator for OpenAISseIter<I>
where
    I: Iterator,
    I::Item: AsRef<str>,
{
    type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;

    /// Pulls the next parsed `ChatCompletionsStreamResponse` off the stream.
    ///
    /// Returns `None` at end of stream or on OpenAI's `[DONE]` sentinel;
    /// malformed JSON payloads yield an `Err` item. Blank keep-alive lines,
    /// non-`data:` fields, and ping frames are skipped.
    fn next(&mut self) -> Option<Self::Item> {
        for line in &mut self.sse_stream.lines {
            let line = line.as_ref();

            // `strip_prefix` both filters non-data lines (empty lines, other
            // SSE fields) and removes the prefix, replacing the manual,
            // panic-prone slice `&line[6..]`.
            let data = match line.strip_prefix("data: ") {
                Some(payload) => payload,
                None => continue,
            };

            // OpenAI terminates the stream with a literal "[DONE]" sentinel.
            if data == "[DONE]" {
                return None;
            }

            // Ping frames (usually from other providers) are tolerated.
            if data == r#"{"type": "ping"}"# {
                continue;
            }

            // OpenAI-specific parsing of ChatCompletionsStreamResponse.
            return match serde_json::from_str::<ChatCompletionsStreamResponse>(data) {
                Ok(response) => Some(Ok(Box::new(response))),
                Err(e) => Some(Err(Box::new(OpenAIStreamError::InvalidStreamingData(
                    format!("Error parsing OpenAI streaming data: {}, data: {}", e, data),
                )))),
            };
        }
        None
    }
}
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
fn content_delta(&self) -> Option<&str> {
@ -680,6 +653,10 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
Role::Tool => "tool",
}))
}
fn event_type(&self) -> Option<&str> {
None // OpenAI doesn't use event types in SSE
}
}
@ -982,13 +959,13 @@ mod tests {
let api = OpenAIApi::ChatCompletions;
// Test trait methods
assert_eq!(api.endpoint(), "/v1/chat/completions");
assert_eq!(api.endpoint(), CHAT_COMPLETIONS_PATH);
assert!(api.supports_streaming());
assert!(api.supports_tools());
assert!(api.supports_vision());
// Test from_endpoint
let found_api = OpenAIApi::from_endpoint("/v1/chat/completions");
let found_api = OpenAIApi::from_endpoint(CHAT_COMPLETIONS_PATH);
assert_eq!(found_api, Some(OpenAIApi::ChatCompletions));
let not_found = OpenAIApi::from_endpoint("/v1/unknown");
@ -1139,4 +1116,84 @@ mod tests {
let invalid_result: Result<ToolChoice, _> = serde_json::from_value(json!("invalid"));
assert!(invalid_result.is_err());
}
#[test]
fn test_chat_completions_response_with_service_tier() {
// Test that ChatCompletionsResponse can deserialize OpenAI responses with service_tier field
let json_response = r#"{
"id": "chatcmpl-CAJc2Df6QCc7Mv3RP0Cf2xlbDV1x2",
"object": "chat.completion",
"created": 1756574706,
"model": "gpt-4o-2024-08-06",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response content",
"annotations": []
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 65,
"completion_tokens": 184,
"total_tokens": 249,
"prompt_tokens_details": {
"cached_tokens": 0,
"audio_tokens": 0
},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
}
},
"service_tier": "default",
"system_fingerprint": "fp_f33640a400"
}"#;
let response: ChatCompletionsResponse = serde_json::from_str(json_response).unwrap();
assert_eq!(response.id, "chatcmpl-CAJc2Df6QCc7Mv3RP0Cf2xlbDV1x2");
assert_eq!(response.object, "chat.completion");
assert_eq!(response.created, 1756574706);
assert_eq!(response.model, "gpt-4o-2024-08-06");
assert_eq!(response.service_tier, Some("default".to_string()));
assert_eq!(response.system_fingerprint, Some("fp_f33640a400".to_string()));
assert_eq!(response.choices.len(), 1);
assert_eq!(response.usage.prompt_tokens, 65);
assert_eq!(response.usage.completion_tokens, 184);
assert_eq!(response.usage.total_tokens, 249);
}
#[test]
fn test_chat_completions_response_without_service_tier() {
// Test that ChatCompletionsResponse can deserialize responses without service_tier (backward compatibility)
let json_response = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-4",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}"#;
let response: ChatCompletionsResponse = serde_json::from_str(json_response).unwrap();
assert_eq!(response.id, "chatcmpl-123");
assert_eq!(response.service_tier, None); // Should be None when not present
assert_eq!(response.system_fingerprint, None);
}
}