add support for v1/messages and transformations (#558)

* pushing draft PR

* transformations are working; now need to add some tests

* updated tests and added the necessary response transformations for Anthropic's Messages response object

* fixed bugs for integration tests

* fixed doc tests

* fixed serialization issues with enums on response

* adding some debug logs to aid troubleshooting

* fixed issues with non-streaming responses

* updated the stream_context to update response bytes

* the serialized bytes length must be set on the response side

* fixed the debug statement that was causing the integration tests for wasm to fail

* fixing json parsing errors

* intentionally removing the headers

* making sure that we convert the raw bytes to the correct provider type upstream

* fixing non-streaming responses to transform correctly

* /v1/messages works with transformations to and from /v1/chat/completions

* updating the CLI and demos to support anthropic vs. claude

* adding the anthropic key to the preference based routing tests

* fixed test cases and added more structured logs

* fixed integration tests and cleaned up logs

* added python client tests for anthropic and openai

* cleaned up logs and fixed issue with connectivity for llm gateway in weather forecast demo

* fixing the tests; the Python dependency order was broken

* updated the OpenAI client to fix demos

* removed the raw response debug statement

* fixed the duplicate cloning issue and cleaned up the ProviderRequestType enum and traits

* fixing logs

* moved away from string literals to consts

* fixed streaming from Anthropic Client to OpenAI

* removed debug statement that would likely trip up integration tests

* fixed integration tests for llm_gateway

* cleaned up test cases and removed unnecessary crates

* fixing comments from PR

* fixed a bug whereby we were sending an OpenAIChatCompletions request object to llm_gateway even when the request was actually AnthropicMessages

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-9.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-10.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-41.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-136.local>

@@ -1,41 +1,17 @@
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::anthropic::MessagesRequest;
use crate::clients::endpoints::SupportedAPIs;
use serde_json::Value;
use std::error::Error;
use std::fmt;
use std::collections::HashMap;
#[derive(Clone)]
pub enum ProviderRequestType {
ChatCompletionsRequest(ChatCompletionsRequest),
MessagesRequest(MessagesRequest),
// add more request types here
}
pub trait ProviderRequest: Send + Sync {
/// Extract the model name from the request
fn model(&self) -> &str;
@@ -54,46 +30,129 @@ pub trait ProviderRequest: Send + Sync {
/// Convert the request to bytes for transmission
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
fn metadata(&self) -> &Option<HashMap<String, Value>>;
/// Remove a metadata key from the request and return true if the key was present
fn remove_metadata_key(&mut self, key: &str) -> bool;
}
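// Example (sketch): downstream code can stay API-agnostic by programming against the
// ProviderRequest trait, e.g. logging routing facts without caring whether the client
// spoke /v1/chat/completions or /v1/messages. `log_route` is an illustrative helper,
// not part of this change:
//
//     fn log_route(req: &ProviderRequestType) {
//         println!("model={} streaming={}", req.model(), req.is_streaming());
//     }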
impl ProviderRequest for ProviderRequestType {
fn model(&self) -> &str {
match self {
Self::ChatCompletionsRequest(r) => r.model(),
Self::MessagesRequest(r) => r.model(),
}
}
fn set_model(&mut self, model: String) {
match self {
Self::ChatCompletionsRequest(r) => r.set_model(model),
Self::MessagesRequest(r) => r.set_model(model),
}
}
fn is_streaming(&self) -> bool {
match self {
Self::ChatCompletionsRequest(r) => r.is_streaming(),
Self::MessagesRequest(r) => r.is_streaming(),
}
}
fn extract_messages_text(&self) -> String {
match self {
Self::ChatCompletionsRequest(r) => r.extract_messages_text(),
Self::MessagesRequest(r) => r.extract_messages_text(),
}
}
fn get_recent_user_message(&self) -> Option<String> {
match self {
Self::ChatCompletionsRequest(r) => r.get_recent_user_message(),
Self::MessagesRequest(r) => r.get_recent_user_message(),
}
}
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
match self {
Self::ChatCompletionsRequest(r) => r.to_bytes(),
Self::MessagesRequest(r) => r.to_bytes(),
}
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
match self {
Self::ChatCompletionsRequest(r) => r.metadata(),
Self::MessagesRequest(r) => r.metadata(),
}
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
match self {
Self::ChatCompletionsRequest(r) => r.remove_metadata_key(key),
Self::MessagesRequest(r) => r.remove_metadata_key(key),
}
}
}
/// Parse the client API from a byte slice.
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
type Error = std::io::Error;
fn try_from((bytes, client_api): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
// Use SupportedAPIs to determine the appropriate request type
match client_api {
SupportedAPIs::OpenAIChatCompletions(_) => {
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
}
SupportedAPIs::AnthropicMessagesAPI(_) => {
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::MessagesRequest(messages_request))
}
}
}
}
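// Example (sketch): parsing raw client bytes into the right variant based on the API
// the client called. `body` is an illustrative byte buffer; the SupportedAPIs variant
// paths are the same ones the tests below import, and error handling is elided:
//
//     let api = SupportedAPIs::AnthropicMessagesAPI(Messages);
//     let request = ProviderRequestType::try_from((body.as_slice(), &api))?;
//     assert!(matches!(request, ProviderRequestType::MessagesRequest(_)));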
/// Convert a ProviderRequestType into the ProviderRequestType expected by the upstream API (SupportedAPIs)
impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
type Error = ProviderRequestError;
fn try_from((request, upstream_api): (ProviderRequestType, &SupportedAPIs)) -> Result<Self, Self::Error> {
match (request, upstream_api) {
// Same API: no conversion needed; pass the request through unchanged
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
Ok(ProviderRequestType::MessagesRequest(messages_req))
}
// Cross-API conversion: transform the request into the upstream API's shape
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let messages_req = MessagesRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
source: Some(Box::new(e))
})?;
Ok(ProviderRequestType::MessagesRequest(messages_req))
}
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
let chat_req = ChatCompletionsRequest::try_from(messages_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
source: Some(Box::new(e))
})?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
}
}
}
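// Example (sketch): the end-to-end transformation path this impl enables. A request
// parsed from the client's API can be re-targeted at whichever API the upstream
// provider speaks, then serialized back to bytes. `client_api` and `upstream_api`
// are illustrative SupportedAPIs values, and error-type unification is elided:
//
//     let request = ProviderRequestType::try_from((body, &client_api))?;
//     let upstream_request = ProviderRequestType::try_from((request, &upstream_api))?;
//     let bytes = upstream_request.to_bytes()?;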
/// Error types for provider operations
#[derive(Debug)]
@@ -113,3 +172,194 @@ impl Error for ProviderRequestError {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::anthropic::AnthropicApi::Messages;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::ChatCompletionsRequest;
use crate::clients::transformer::ExtractText;
use serde_json::json;
#[test]
fn test_openai_request_from_bytes() {
let req = json!({
"model": "gpt-4",
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello!"}
]
});
let bytes = serde_json::to_vec(&req).unwrap();
let api = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
let result = ProviderRequestType::try_from((bytes.as_slice(), &api));
assert!(result.is_ok());
match result.unwrap() {
ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.messages.len(), 2);
},
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
#[test]
fn test_anthropic_request_from_bytes_with_endpoint() {
let req = json!({
"model": "claude-3-sonnet",
"system": "You are a helpful assistant",
"max_tokens": 100,
"messages": [
{"role": "user", "content": "Hello!"}
]
});
let bytes = serde_json::to_vec(&req).unwrap();
let endpoint = SupportedAPIs::AnthropicMessagesAPI(Messages);
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
assert!(result.is_ok());
match result.unwrap() {
ProviderRequestType::MessagesRequest(r) => {
assert_eq!(r.model, "claude-3-sonnet");
assert_eq!(r.messages.len(), 1);
},
_ => panic!("Expected MessagesRequest variant"),
}
}
#[test]
fn test_openai_request_from_bytes_with_endpoint() {
let req = json!({
"model": "gpt-4",
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello!"}
]
});
let bytes = serde_json::to_vec(&req).unwrap();
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
assert!(result.is_ok());
match result.unwrap() {
ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.messages.len(), 2);
},
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
#[test]
fn test_anthropic_request_from_bytes_wrong_endpoint() {
let req = json!({
"model": "claude-3-sonnet",
"system": "You are a helpful assistant",
"messages": [
{"role": "user", "content": "Hello!"}
]
});
let bytes = serde_json::to_vec(&req).unwrap();
// Intentionally use OpenAI endpoint for Anthropic payload
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
// Should parse as ChatCompletionsRequest, not error
assert!(result.is_ok());
match result.unwrap() {
ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "claude-3-sonnet");
assert_eq!(r.messages.len(), 1);
},
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
#[test]
fn test_v1_messages_to_v1_chat_completions_roundtrip() {
let anthropic_req = AnthropicMessagesRequest {
model: "claude-3-sonnet".to_string(),
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single("You are a helpful assistant".to_string())),
messages: vec![
crate::apis::anthropic::MessagesMessage {
role: crate::apis::anthropic::MessagesRole::User,
content: crate::apis::anthropic::MessagesMessageContent::Single("Hello!".to_string()),
}
],
max_tokens: 128,
container: None,
mcp_servers: None,
service_tier: None,
thinking: None,
temperature: Some(0.7),
top_p: Some(1.0),
top_k: None,
stream: Some(false),
stop_sequences: Some(vec!["\n".to_string()]),
tools: None,
tool_choice: None,
metadata: None,
};
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()).expect("Anthropic->OpenAI conversion failed");
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req).expect("OpenAI->Anthropic conversion failed");
assert_eq!(anthropic_req.model, anthropic_req2.model);
// Compare system prompt text if present
assert_eq!(
anthropic_req.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }),
anthropic_req2.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None })
);
assert_eq!(anthropic_req.messages[0].role, anthropic_req2.messages[0].role);
// Compare message content text if present
assert_eq!(
anthropic_req.messages[0].content.extract_text(),
anthropic_req2.messages[0].content.extract_text()
);
assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens);
}
#[test]
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
let openai_req = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text("You are a helpful assistant".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
Message {
role: Role::User,
content: MessageContent::Text("Hello!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
],
temperature: Some(0.7),
top_p: Some(1.0),
max_tokens: Some(128),
stream: Some(false),
stop: Some(vec!["\n".to_string()]),
tools: None,
tool_choice: None,
parallel_tool_calls: None,
..Default::default()
};
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()).expect("OpenAI->Anthropic conversion failed");
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req).expect("Anthropic->OpenAI conversion failed");
assert_eq!(openai_req.model, openai_req2.model);
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
assert_eq!(openai_req.messages[0].content.extract_text(), openai_req2.messages[0].content.extract_text());
assert_eq!(openai_req.max_tokens, openai_req2.max_tokens);
}
}