add support for v1/messages and transformations (#558)

* pushing draft PR

* transformations are working. Now need to add some tests next

* updated tests and added necessary response transformations for Anthropic's message response object

* fixed bugs for integration tests

* fixed doc tests

* fixed serialization issues with enums on response

* adding some debug logs to help

* fixed issues with non-streaming responses

* updated the stream_context to update response bytes

* the serialized bytes length must be set in the response side

* fixed the debug statement that was causing the integration tests for wasm to fail

* fixing json parsing errors

* intentionally removing the headers

* making sure that we convert the raw bytes to the correct provider type upstream

* fixing non-streaming responses to transform correctly

* /v1/messages works with transformations to and from /v1/chat/completions

* updating the CLI and demos to support anthropic vs. claude

* adding the anthropic key to the preference based routing tests

* fixed test cases and added more structured logs

* fixed integration tests and cleaned up logs

* added python client tests for anthropic and openai

* cleaned up logs and fixed issue with connectivity for llm gateway in weather forecast demo

* fixing the tests. python dependency order was broken

* updated the openAI client to fix demos

* removed the raw response debug statement

* fixed the dup cloning issue and cleaned up the ProviderRequestType enum and traits

* fixing logs

* moved away from string literals to consts

* fixed streaming from Anthropic Client to OpenAI

* removed debug statement that would likely trip up integration tests

* fixed integration tests for llm_gateway

* cleaned up test cases and removed unnecessary crates

* fixing comments from PR

* fixed bug whereby we were sending an OpenAIChatCompletions request object to llm_gateway even though the request may have been AnthropicMessages

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-9.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-10.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-41.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-136.local>
This commit is contained in:
Salman Paracha 2025-09-10 07:40:30 -07:00 committed by GitHub
parent bb71d041a0
commit fb0581fd39
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 2842 additions and 919 deletions

View file

@ -1,9 +1,14 @@
use crate::providers::response::TokenUsage;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
use crate::clients::transformer::ExtractText;
use crate::{MESSAGES_PATH};
// Enum for all supported Anthropic APIs
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@ -17,13 +22,13 @@ pub enum AnthropicApi {
impl ApiDefinition for AnthropicApi {
fn endpoint(&self) -> &'static str {
match self {
AnthropicApi::Messages => "/v1/messages",
AnthropicApi::Messages => MESSAGES_PATH,
}
}
fn from_endpoint(endpoint: &str) -> Option<Self> {
match endpoint {
"/v1/messages" => Some(AnthropicApi::Messages),
MESSAGES_PATH => Some(AnthropicApi::Messages),
_ => None,
}
}
@ -186,6 +191,19 @@ pub enum MessagesContentBlock {
},
}
/// Text extraction over an ordered list of content blocks.
impl ExtractText for Vec<MessagesContentBlock> {
    /// Returns the text of every `Text` block, in order, separated by `"\n"`.
    ///
    /// Blocks of any other variant are skipped. When the list holds no text
    /// blocks the result is an empty string.
    fn extract_text(&self) -> String {
        let mut text_segments: Vec<&str> = Vec::new();
        for block in self {
            if let MessagesContentBlock::Text { text } = block {
                text_segments.push(text.as_str());
            }
        }
        text_segments.join("\n")
    }
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum MessagesImageSource {
@ -220,6 +238,15 @@ pub enum MessagesMessageContent {
Blocks(Vec<MessagesContentBlock>),
}
/// Text extraction for message content, which is either a plain string or a
/// list of content blocks.
impl ExtractText for MessagesMessageContent {
    /// Returns the plain text carried by this content.
    ///
    /// A `Single` value is returned as an owned copy; a `Blocks` value defers
    /// to the block list's own `extract_text`.
    fn extract_text(&self) -> String {
        match self {
            Self::Blocks(blocks) => blocks.extract_text(),
            Self::Single(body) => body.to_owned(),
        }
    }
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum MessagesSystemPrompt {
@ -369,6 +396,121 @@ impl MessagesRequest {
}
}
/// Builds a [`MessagesRequest`] from a raw JSON byte slice.
impl TryFrom<&[u8]> for MessagesRequest {
    type Error = serde_json::Error;

    /// Deserializes `payload` as JSON.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error unchanged when deserialization fails.
    fn try_from(payload: &[u8]) -> Result<Self, Self::Error> {
        let parsed: Self = serde_json::from_slice(payload)?;
        Ok(parsed)
    }
}
/// Token accounting for a Messages response, read from its `usage` field.
impl TokenUsage for MessagesResponse {
    /// Number of tokens the model generated (`usage.output_tokens`).
    fn completion_tokens(&self) -> usize {
        self.usage.output_tokens as usize
    }

    /// Number of tokens consumed by the prompt (`usage.input_tokens`).
    fn prompt_tokens(&self) -> usize {
        self.usage.input_tokens as usize
    }

    /// Sum of prompt and completion tokens.
    ///
    /// Each counter is converted to `usize` *before* the addition, so the sum
    /// is no longer computed in the (possibly narrower) field type, where it
    /// could overflow and wrap in release builds. The addition saturates at
    /// `usize::MAX` instead of panicking or wrapping.
    ///
    /// NOTE(review): assumes the usage counters are unsigned integers no wider
    /// than `usize` — confirm against the `usage` struct definition.
    fn total_tokens(&self) -> usize {
        self.prompt_tokens()
            .saturating_add(self.completion_tokens())
    }
}
impl ProviderResponse for MessagesResponse {
fn usage(&self) -> Option<&dyn TokenUsage> {
Some(self)
}
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
}
}
impl ProviderRequest for MessagesRequest {
fn model(&self) -> &str {
&self.model
}
fn set_model(&mut self, model: String) {
self.model = model;
}
fn is_streaming(&self) -> bool {
self.stream.unwrap_or(false)
}
fn extract_messages_text(&self) -> String {
let mut text_parts = Vec::new();
// Include system prompt if present
if let Some(system) = &self.system {
match system {
MessagesSystemPrompt::Single(s) => text_parts.push(s.clone()),
MessagesSystemPrompt::Blocks(blocks) => {
for block in blocks {
if let MessagesContentBlock::Text { text } = block {
text_parts.push(text.clone());
}
}
}
}
}
// Extract text from all messages
for message in &self.messages {
match &message.content {
MessagesMessageContent::Single(text) => text_parts.push(text.clone()),
MessagesMessageContent::Blocks(blocks) => {
for block in blocks {
if let MessagesContentBlock::Text { text } = block {
text_parts.push(text.clone());
}
}
}
}
}
text_parts.join(" ")
}
fn get_recent_user_message(&self) -> Option<String> {
// Find the most recent user message
for message in self.messages.iter().rev() {
if message.role == MessagesRole::User {
match &message.content {
MessagesMessageContent::Single(text) => return Some(text.clone()),
MessagesMessageContent::Blocks(blocks) => {
for block in blocks {
if let MessagesContentBlock::Text { text } = block {
return Some(text.clone());
}
}
}
}
}
}
None
}
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
serde_json::to_vec(self).map_err(|e| ProviderRequestError {
message: format!("Failed to serialize MessagesRequest: {}", e),
source: Some(Box::new(e)),
})
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata;
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
if let Some(ref mut metadata) = self.metadata {
metadata.remove(key).is_some()
} else {
false
}
}
}
impl MessagesResponse {
pub fn api_type() -> AnthropicApi {
AnthropicApi::Messages
@ -381,6 +523,54 @@ impl MessagesStreamEvent {
}
}
impl MessagesRole {
pub fn as_str(&self) -> &'static str {
match self {
MessagesRole::User => "user",
MessagesRole::Assistant => "assistant",
}
}
}
/// Exposes Messages stream events through the provider-agnostic streaming trait.
impl ProviderStreamResponse for MessagesStreamEvent {
    /// Text carried by a `ContentBlockDelta` event whose delta is a `TextDelta`.
    ///
    /// Every other event, and every other delta kind, yields `None`.
    fn content_delta(&self) -> Option<&str> {
        match self {
            Self::ContentBlockDelta {
                delta: MessagesContentDelta::TextDelta { text },
                ..
            } => Some(text),
            _ => None,
        }
    }

    /// `true` only for the `MessageStop` event.
    fn is_final(&self) -> bool {
        matches!(self, Self::MessageStop)
    }

    /// Role of the message announced by a `MessageStart` event; `None` for
    /// every other event.
    fn role(&self) -> Option<&str> {
        if let Self::MessageStart { message } = self {
            Some(message.role.as_str())
        } else {
            None
        }
    }

    /// Snake-case name of the event variant; always `Some`.
    fn event_type(&self) -> Option<&str> {
        let name = match self {
            Self::MessageStart { .. } => "message_start",
            Self::ContentBlockStart { .. } => "content_block_start",
            Self::ContentBlockDelta { .. } => "content_block_delta",
            Self::ContentBlockStop { .. } => "content_block_stop",
            Self::MessageDelta { .. } => "message_delta",
            Self::MessageStop => "message_stop",
            Self::Ping => "ping",
        };
        Some(name)
    }
}
#[cfg(test)]
mod tests {
use super::*;
@ -878,13 +1068,13 @@ mod tests {
let api = AnthropicApi::Messages;
// Test trait methods
assert_eq!(api.endpoint(), "/v1/messages");
assert_eq!(api.endpoint(), MESSAGES_PATH);
assert!(api.supports_streaming());
assert!(api.supports_tools());
assert!(api.supports_vision());
// Test from_endpoint trait method
let found_api = AnthropicApi::from_endpoint("/v1/messages");
let found_api = AnthropicApi::from_endpoint(MESSAGES_PATH);
assert_eq!(found_api, Some(AnthropicApi::Messages));
let not_found = AnthropicApi::from_endpoint("/v1/unknown");