add native Gemini provider support via hermesllm transforms

Adil Hafeez 2026-03-12 12:27:38 +00:00
parent 5400b0a2fa
commit 053108b96c
16 changed files with 2416 additions and 10 deletions
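For orientation before the per-file diff, a minimal sketch of the flow this commit enables. Assumptions: the crate is consumed as `hermesllm`, the items below are publicly visible via the re-exports added in this commit, and `parse_native_gemini` is a hypothetical caller, not crate API.

use hermesllm::apis::gemini::{GeminiApi, GenerateContentRequest};
use hermesllm::apis::ApiDefinition;

/// Detect a native Gemini endpoint by its suffix, then parse the body.
/// The model id lives in the URL path, not in the JSON body, so routing
/// keys off `:generateContent` / `:streamGenerateContent`.
fn parse_native_gemini(path: &str, body: &[u8]) -> Option<(GeminiApi, GenerateContentRequest)> {
    let api = GeminiApi::from_endpoint(path)?;
    let req = GenerateContentRequest::try_from(body).ok()?;
    Some((api, req))
}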

View file

@@ -53,7 +53,9 @@ pub async fn router_chat_get_upstream_model(
ProviderRequestType::MessagesRequest(_)
| ProviderRequestType::BedrockConverse(_)
| ProviderRequestType::BedrockConverseStream(_)
| ProviderRequestType::ResponsesAPIRequest(_),
| ProviderRequestType::ResponsesAPIRequest(_)
| ProviderRequestType::GeminiGenerateContent(_)
| ProviderRequestType::GeminiStreamGenerateContent(_),
) => {
warn!("unexpected: got non-ChatCompletions request after converting to OpenAI format");
return Err(RoutingError::internal_error(

View file

@@ -0,0 +1,744 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::TokenUsage;
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::transforms::lib::ExtractText;
use crate::GENERATE_CONTENT_PATH_SUFFIX;
// ============================================================================
// GEMINI GENERATE CONTENT API ENUMERATION
// ============================================================================
/// Enum for all supported Gemini GenerateContent APIs
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GeminiApi {
GenerateContent,
StreamGenerateContent,
}
impl ApiDefinition for GeminiApi {
fn endpoint(&self) -> &'static str {
match self {
GeminiApi::GenerateContent => ":generateContent",
GeminiApi::StreamGenerateContent => ":streamGenerateContent",
}
}
fn from_endpoint(endpoint: &str) -> Option<Self> {
if endpoint.ends_with(":streamGenerateContent") {
Some(GeminiApi::StreamGenerateContent)
} else if endpoint.ends_with(GENERATE_CONTENT_PATH_SUFFIX) {
Some(GeminiApi::GenerateContent)
} else {
None
}
}
fn supports_streaming(&self) -> bool {
match self {
GeminiApi::GenerateContent => false,
GeminiApi::StreamGenerateContent => true,
}
}
fn supports_tools(&self) -> bool {
true
}
fn supports_vision(&self) -> bool {
true
}
fn all_variants() -> Vec<Self> {
vec![GeminiApi::GenerateContent, GeminiApi::StreamGenerateContent]
}
}
// ============================================================================
// REQUEST TYPES
// ============================================================================
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
/// Internal model field — not part of Gemini wire format (model is in the URL).
/// Populated during parsing and used for routing.
#[serde(skip_serializing, default)]
pub model: String,
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub tools: Option<Vec<Tool>>,
pub tool_config: Option<ToolConfig>,
pub safety_settings: Option<Vec<SafetySetting>>,
pub system_instruction: Option<Content>,
pub cached_content: Option<String>,
#[serde(skip_serializing)]
pub metadata: Option<HashMap<String, Value>>,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
pub role: Option<String>,
pub parts: Vec<Part>,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Part {
pub text: Option<String>,
pub inline_data: Option<InlineData>,
pub function_call: Option<FunctionCall>,
pub function_response: Option<FunctionResponse>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineData {
pub mime_type: String,
pub data: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCall {
pub name: String,
pub args: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponse {
pub name: String,
pub response: Value,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub top_k: Option<u32>,
pub max_output_tokens: Option<u32>,
pub stop_sequences: Option<Vec<String>>,
pub response_mime_type: Option<String>,
pub candidate_count: Option<u32>,
pub presence_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub function_declarations: Option<Vec<FunctionDeclaration>>,
pub code_execution: Option<Value>,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
pub function_calling_config: FunctionCallingConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
pub mode: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: String,
pub threshold: String,
}
// ============================================================================
// RESPONSE TYPES
// ============================================================================
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
pub candidates: Option<Vec<Candidate>>,
pub usage_metadata: Option<UsageMetadata>,
pub prompt_feedback: Option<PromptFeedback>,
pub model_version: Option<String>,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
pub content: Option<Content>,
pub finish_reason: Option<String>,
pub safety_ratings: Option<Vec<SafetyRating>>,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
pub prompt_token_count: Option<u32>,
pub candidates_token_count: Option<u32>,
pub total_token_count: Option<u32>,
}
impl TokenUsage for UsageMetadata {
fn completion_tokens(&self) -> usize {
self.candidates_token_count.unwrap_or(0) as usize
}
fn prompt_tokens(&self) -> usize {
self.prompt_token_count.unwrap_or(0) as usize
}
fn total_tokens(&self) -> usize {
self.total_token_count.unwrap_or(0) as usize
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
pub block_reason: Option<String>,
pub safety_ratings: Option<Vec<SafetyRating>>,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
pub category: String,
pub probability: String,
pub blocked: Option<bool>,
}
// ============================================================================
// PROVIDER REQUEST TRAIT IMPLEMENTATION
// ============================================================================
impl ProviderRequest for GenerateContentRequest {
fn model(&self) -> &str {
&self.model
}
fn set_model(&mut self, model: String) {
self.model = model;
}
fn is_streaming(&self) -> bool {
// Gemini uses URL-based streaming, not a field in the request body
false
}
fn extract_messages_text(&self) -> String {
let mut parts_text = Vec::new();
for content in &self.contents {
for part in &content.parts {
if let Some(text) = &part.text {
parts_text.push(text.clone());
}
}
}
if let Some(system) = &self.system_instruction {
for part in &system.parts {
if let Some(text) = &part.text {
parts_text.push(text.clone());
}
}
}
parts_text.join(" ")
}
fn get_recent_user_message(&self) -> Option<String> {
self.contents
.iter()
.rev()
.find(|c| c.role.as_deref() == Some("user"))
.and_then(|c| c.parts.iter().find_map(|p| p.text.clone()))
}
fn get_tool_names(&self) -> Option<Vec<String>> {
self.tools.as_ref().map(|tools| {
tools
.iter()
.filter_map(|t| t.function_declarations.as_ref())
.flatten()
.map(|f| f.name.clone())
.collect()
})
}
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
serde_json::to_vec(self).map_err(|e| ProviderRequestError {
message: format!("Failed to serialize GenerateContentRequest: {}", e),
source: Some(Box::new(e)),
})
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
&self.metadata
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
if let Some(ref mut metadata) = self.metadata {
metadata.remove(key).is_some()
} else {
false
}
}
fn get_temperature(&self) -> Option<f32> {
self.generation_config
.as_ref()
.and_then(|gc| gc.temperature)
}
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
use crate::apis::openai::{Message, MessageContent, Role};
let mut messages = Vec::new();
// Convert system instruction
if let Some(system) = &self.system_instruction {
let text = system
.parts
.iter()
.filter_map(|p| p.text.clone())
.collect::<Vec<_>>()
.join("");
if !text.is_empty() {
messages.push(Message {
role: Role::System,
content: Some(MessageContent::Text(text)),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
}
// Convert contents
for content in &self.contents {
let role = match content.role.as_deref() {
Some("model") => Role::Assistant,
_ => Role::User,
};
let text = content
.parts
.iter()
.filter_map(|p| p.text.clone())
.collect::<Vec<_>>()
.join("");
messages.push(Message {
role,
content: Some(MessageContent::Text(text)),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
messages
}
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
use crate::apis::openai::Role;
self.contents.clear();
self.system_instruction = None;
for msg in messages {
let text = msg.content.extract_text();
match msg.role {
Role::System => {
self.system_instruction = Some(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
}],
});
}
Role::User => {
self.contents.push(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
}],
});
}
Role::Assistant => {
self.contents.push(Content {
role: Some("model".to_string()),
parts: vec![Part {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
}],
});
}
Role::Tool => {
self.contents.push(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
}],
});
}
}
}
}
}
// ============================================================================
// PROVIDER STREAM RESPONSE TRAIT IMPLEMENTATION
// ============================================================================
impl ProviderStreamResponse for GenerateContentResponse {
fn content_delta(&self) -> Option<&str> {
self.candidates
.as_ref()
.and_then(|candidates| candidates.first())
.and_then(|candidate| candidate.content.as_ref())
.and_then(|content| content.parts.first())
.and_then(|part| part.text.as_deref())
}
fn is_final(&self) -> bool {
self.candidates
.as_ref()
.and_then(|candidates| candidates.first())
.and_then(|candidate| candidate.finish_reason.as_deref())
.map(|reason| reason == "STOP" || reason == "MAX_TOKENS" || reason == "SAFETY")
.unwrap_or(false)
}
fn role(&self) -> Option<&str> {
self.candidates
.as_ref()
.and_then(|candidates| candidates.first())
.and_then(|candidate| candidate.content.as_ref())
.and_then(|content| content.role.as_deref())
}
fn event_type(&self) -> Option<&str> {
None // Gemini doesn't use SSE event types
}
}
// ============================================================================
// SERDE PARSING
// ============================================================================
impl TryFrom<&[u8]> for GenerateContentRequest {
type Error = serde_json::Error;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_slice(bytes)
}
}
impl TryFrom<&[u8]> for GenerateContentResponse {
type Error = serde_json::Error;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_slice(bytes)
}
}
// ============================================================================
// TESTS
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_gemini_api_from_endpoint() {
assert_eq!(
GeminiApi::from_endpoint("/v1beta/models/gemini-pro:generateContent"),
Some(GeminiApi::GenerateContent)
);
assert_eq!(
GeminiApi::from_endpoint("/v1beta/models/gemini-pro:streamGenerateContent"),
Some(GeminiApi::StreamGenerateContent)
);
assert_eq!(GeminiApi::from_endpoint("/v1/chat/completions"), None);
}
#[test]
fn test_generate_content_request_serde() {
let json_str = json!({
"contents": [{
"role": "user",
"parts": [{"text": "Hello"}]
}],
"generationConfig": {
"temperature": 0.7,
"maxOutputTokens": 1024
}
});
let req: GenerateContentRequest = serde_json::from_value(json_str).unwrap();
assert_eq!(req.contents.len(), 1);
assert_eq!(req.contents[0].role, Some("user".to_string()));
assert_eq!(
req.generation_config.as_ref().unwrap().temperature,
Some(0.7)
);
assert_eq!(
req.generation_config.as_ref().unwrap().max_output_tokens,
Some(1024)
);
// Roundtrip
let bytes = serde_json::to_vec(&req).unwrap();
let req2: GenerateContentRequest = serde_json::from_slice(&bytes).unwrap();
assert_eq!(req2.contents.len(), 1);
}
#[test]
fn test_generate_content_response_serde() {
let json_str = json!({
"candidates": [{
"content": {
"role": "model",
"parts": [{"text": "Hello! How can I help?"}]
},
"finishReason": "STOP"
}],
"usageMetadata": {
"promptTokenCount": 5,
"candidatesTokenCount": 7,
"totalTokenCount": 12
}
});
let resp: GenerateContentResponse = serde_json::from_value(json_str).unwrap();
assert!(resp.candidates.is_some());
let candidates = resp.candidates.as_ref().unwrap();
assert_eq!(candidates.len(), 1);
assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP"));
assert_eq!(
resp.usage_metadata.as_ref().unwrap().prompt_token_count,
Some(5)
);
}
#[test]
fn test_generate_content_request_with_tools() {
let json_str = json!({
"contents": [{
"role": "user",
"parts": [{"text": "What's the weather?"}]
}],
"tools": [{
"functionDeclarations": [{
"name": "get_weather",
"description": "Get weather info",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"}
}
}
}]
}],
"toolConfig": {
"functionCallingConfig": {
"mode": "AUTO"
}
}
});
let req: GenerateContentRequest = serde_json::from_value(json_str).unwrap();
assert!(req.tools.is_some());
let tools = req.tools.as_ref().unwrap();
assert_eq!(tools.len(), 1);
let decls = tools[0].function_declarations.as_ref().unwrap();
assert_eq!(decls[0].name, "get_weather");
assert_eq!(
req.tool_config
.as_ref()
.unwrap()
.function_calling_config
.mode,
"AUTO"
);
}
#[test]
fn test_generate_content_response_with_function_call() {
let json_str = json!({
"candidates": [{
"content": {
"role": "model",
"parts": [{
"functionCall": {
"name": "get_weather",
"args": {"location": "NYC"}
}
}]
},
"finishReason": "STOP"
}]
});
let resp: GenerateContentResponse = serde_json::from_value(json_str).unwrap();
let candidates = resp.candidates.as_ref().unwrap();
let parts = &candidates[0].content.as_ref().unwrap().parts;
assert!(parts[0].function_call.is_some());
assert_eq!(parts[0].function_call.as_ref().unwrap().name, "get_weather");
}
#[test]
fn test_stream_response_content_delta() {
let resp = GenerateContentResponse {
candidates: Some(vec![Candidate {
content: Some(Content {
role: Some("model".to_string()),
parts: vec![Part {
text: Some("Hello".to_string()),
inline_data: None,
function_call: None,
function_response: None,
}],
}),
finish_reason: None,
safety_ratings: None,
}]),
usage_metadata: None,
prompt_feedback: None,
model_version: None,
};
assert_eq!(resp.content_delta(), Some("Hello"));
assert!(!resp.is_final());
}
#[test]
fn test_stream_response_is_final() {
let resp = GenerateContentResponse {
candidates: Some(vec![Candidate {
content: Some(Content {
role: Some("model".to_string()),
parts: vec![Part {
text: Some("Done".to_string()),
inline_data: None,
function_call: None,
function_response: None,
}],
}),
finish_reason: Some("STOP".to_string()),
safety_ratings: None,
}]),
usage_metadata: None,
prompt_feedback: None,
model_version: None,
};
assert!(resp.is_final());
}
#[test]
fn test_provider_request_extract_text() {
let req = GenerateContentRequest {
model: "gemini-pro".to_string(),
contents: vec![Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some("Hello world".to_string()),
inline_data: None,
function_call: None,
function_response: None,
}],
}],
system_instruction: Some(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some("Be helpful".to_string()),
inline_data: None,
function_call: None,
function_response: None,
}],
}),
..Default::default()
};
let text = req.extract_messages_text();
assert!(text.contains("Hello world"));
assert!(text.contains("Be helpful"));
}
#[test]
fn test_provider_request_get_tool_names() {
let req = GenerateContentRequest {
model: "gemini-pro".to_string(),
contents: vec![],
tools: Some(vec![Tool {
function_declarations: Some(vec![
FunctionDeclaration {
name: "func_a".to_string(),
description: None,
parameters: None,
},
FunctionDeclaration {
name: "func_b".to_string(),
description: None,
parameters: None,
},
]),
code_execution: None,
}]),
..Default::default()
};
let names = req.get_tool_names().unwrap();
assert_eq!(names, vec!["func_a", "func_b"]);
}
}
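The ProviderStreamResponse impl above is what per-chunk handling builds on. A sketch of consuming a stream, assuming the crate name `hermesllm`, its root re-export of the trait, and a hypothetical `print_stream` caller; each SSE `data:` payload from `:streamGenerateContent?alt=sse` decodes to a standalone chunk:

use hermesllm::apis::gemini::GenerateContentResponse;
use hermesllm::ProviderStreamResponse;

fn print_stream(payloads: &[&[u8]]) {
    for payload in payloads {
        if let Ok(chunk) = GenerateContentResponse::try_from(*payload) {
            if let Some(delta) = chunk.content_delta() {
                print!("{delta}");
            }
            if chunk.is_final() {
                // finishReason of STOP, MAX_TOKENS, or SAFETY
                break;
            }
        }
    }
}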

View file

@@ -1,5 +1,6 @@
pub mod amazon_bedrock;
pub mod anthropic;
pub mod gemini;
pub mod openai;
pub mod openai_responses;
pub mod streaming_shapes;
@@ -10,6 +11,7 @@ pub use amazon_bedrock::{
Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
};
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
pub use gemini::{GeminiApi, GenerateContentRequest, GenerateContentResponse};
pub use openai::{
ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
};

View file

@@ -1,4 +1,4 @@
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, GeminiApi, OpenAIApi};
use crate::ProviderId;
use std::fmt;
@@ -8,6 +8,7 @@ pub enum SupportedAPIsFromClient {
OpenAIChatCompletions(OpenAIApi),
AnthropicMessagesAPI(AnthropicApi),
OpenAIResponsesAPI(OpenAIApi),
GeminiGenerateContentAPI(GeminiApi),
}
#[derive(Debug, Clone, PartialEq)]
@@ -17,6 +18,8 @@ pub enum SupportedUpstreamAPIs {
AmazonBedrockConverse(AmazonBedrockApi),
AmazonBedrockConverseStream(AmazonBedrockApi),
OpenAIResponsesAPI(OpenAIApi),
GeminiGenerateContent(GeminiApi),
GeminiStreamGenerateContent(GeminiApi),
}
impl fmt::Display for SupportedAPIsFromClient {
@@ -31,6 +34,9 @@ impl fmt::Display for SupportedAPIsFromClient {
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => {
write!(f, "OpenAI Responses ({})", api.endpoint())
}
SupportedAPIsFromClient::GeminiGenerateContentAPI(api) => {
write!(f, "Gemini ({})", api.endpoint())
}
}
}
}
@@ -53,6 +59,12 @@ impl fmt::Display for SupportedUpstreamAPIs {
SupportedUpstreamAPIs::OpenAIResponsesAPI(api) => {
write!(f, "OpenAI Responses ({})", api.endpoint())
}
SupportedUpstreamAPIs::GeminiGenerateContent(api) => {
write!(f, "Gemini ({})", api.endpoint())
}
SupportedUpstreamAPIs::GeminiStreamGenerateContent(api) => {
write!(f, "Gemini Stream ({})", api.endpoint())
}
}
}
}
@@ -60,6 +72,13 @@ impl fmt::Display for SupportedUpstreamAPIs {
impl SupportedAPIsFromClient {
/// Create a SupportedApi from an endpoint path
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
// Check Gemini first since it uses suffix matching (`:generateContent`)
if let Some(gemini_api) = GeminiApi::from_endpoint(endpoint) {
return Some(SupportedAPIsFromClient::GeminiGenerateContentAPI(
gemini_api,
));
}
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
// Check if this is the Responses API endpoint
if openai_api == OpenAIApi::Responses {
@@ -82,6 +101,7 @@ impl SupportedAPIsFromClient {
SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(),
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(),
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(),
SupportedAPIsFromClient::GeminiGenerateContentAPI(api) => api.endpoint(),
}
}
@@ -145,7 +165,18 @@ impl SupportedAPIsFromClient {
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
build_endpoint("/v1beta/openai", endpoint_suffix)
// Use native Gemini endpoint
if !is_streaming {
build_endpoint(
"/v1beta",
&format!("/models/{}:generateContent", model_id),
)
} else {
build_endpoint(
"/v1beta",
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
)
}
} else {
build_endpoint("/v1", endpoint_suffix)
}
@@ -178,6 +209,20 @@ impl SupportedAPIsFromClient {
build_endpoint("/v1", "/chat/completions")
}
}
ProviderId::Gemini => {
// Translate Anthropic → Gemini native
if !is_streaming {
build_endpoint(
"/v1beta",
&format!("/models/{}:generateContent", model_id),
)
} else {
build_endpoint(
"/v1beta",
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
)
}
}
_ => build_endpoint("/v1", "/chat/completions"),
}
}
@@ -186,6 +231,20 @@ impl SupportedAPIsFromClient {
match provider_id {
// Providers that support /v1/responses natively
ProviderId::OpenAI | ProviderId::XAI => route_by_provider("/responses"),
ProviderId::Gemini => {
// Translate Responses → Gemini native
if !is_streaming {
build_endpoint(
"/v1beta",
&format!("/models/{}:generateContent", model_id),
)
} else {
build_endpoint(
"/v1beta",
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
)
}
}
// All other providers: translate to /chat/completions
_ => route_by_provider("/chat/completions"),
}
@@ -194,6 +253,33 @@ impl SupportedAPIsFromClient {
// For Chat Completions API, use the standard chat/completions path
route_by_provider("/chat/completions")
}
SupportedAPIsFromClient::GeminiGenerateContentAPI(_) => {
match provider_id {
ProviderId::Gemini => {
// Native Gemini endpoint
if !is_streaming {
build_endpoint(
"/v1beta",
&format!("/models/{}:generateContent", model_id),
)
} else {
build_endpoint(
"/v1beta",
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
)
}
}
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
ProviderId::AmazonBedrock => {
if !is_streaming {
build_endpoint("", &format!("/model/{}/converse", model_id))
} else {
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
}
}
_ => build_endpoint("/v1", "/chat/completions"),
}
}
}
}
}
@@ -201,6 +287,18 @@ impl SupportedAPIsFromClient {
impl SupportedUpstreamAPIs {
/// Create a SupportedUpstreamApi from an endpoint path
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
// Check Gemini first since it uses suffix matching
if let Some(gemini_api) = GeminiApi::from_endpoint(endpoint) {
return match gemini_api {
GeminiApi::GenerateContent => {
Some(SupportedUpstreamAPIs::GeminiGenerateContent(gemini_api))
}
GeminiApi::StreamGenerateContent => Some(
SupportedUpstreamAPIs::GeminiStreamGenerateContent(gemini_api),
),
};
}
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
// Check if this is the Responses API endpoint
if openai_api == OpenAIApi::Responses {
@@ -396,7 +494,7 @@ mod tests {
"/openai/deployments/gpt-4/chat/completions?api-version=2025-01-01-preview"
);
// Test Gemini provider
// Test Gemini provider (uses native Gemini API with transforms)
assert_eq!(
api.target_endpoint_for_provider(
&ProviderId::Gemini,
@@ -405,7 +503,7 @@ mod tests {
false,
None
),
"/v1beta/openai/chat/completions"
"/v1beta/models/gemini-pro:generateContent"
);
}
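The two URL shapes asserted here are the whole upstream contract. A sketch of the rule as an illustrative helper (`gemini_upstream_path` is not crate API); `?alt=sse` matters because without it Gemini's streamGenerateContent returns a JSON array of chunks rather than an SSE stream:

fn gemini_upstream_path(model_id: &str, is_streaming: bool) -> String {
    if is_streaming {
        format!("/v1beta/models/{model_id}:streamGenerateContent?alt=sse")
    } else {
        format!("/v1beta/models/{model_id}:generateContent")
    }
}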

View file

@@ -20,6 +20,7 @@ pub use providers::streaming_response::{ProviderStreamResponse, ProviderStreamRe
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
pub const MESSAGES_PATH: &str = "/v1/messages";
pub const GENERATE_CONTENT_PATH_SUFFIX: &str = ":generateContent";
#[cfg(test)]
mod tests {

View file

@@ -1,4 +1,4 @@
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
use crate::apis::{AmazonBedrockApi, AnthropicApi, GeminiApi, OpenAIApi};
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use serde::Deserialize;
use std::collections::HashMap;
@@ -116,7 +116,68 @@ impl ProviderId {
is_streaming: bool,
) -> SupportedUpstreamAPIs {
match (self, client_api) {
// ============================================================================
// Gemini provider — use native Gemini APIs
// ============================================================================
(ProviderId::Gemini, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
if is_streaming {
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
GeminiApi::StreamGenerateContent,
)
} else {
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
}
}
(ProviderId::Gemini, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
if is_streaming {
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
GeminiApi::StreamGenerateContent,
)
} else {
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
}
}
(ProviderId::Gemini, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
if is_streaming {
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
GeminiApi::StreamGenerateContent,
)
} else {
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
}
}
(ProviderId::Gemini, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
if is_streaming {
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
GeminiApi::StreamGenerateContent,
)
} else {
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
}
}
// ============================================================================
// Non-Gemini providers receiving Gemini-format requests
// ============================================================================
(ProviderId::Anthropic, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
}
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
)
} else {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
}
}
(_, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
}
// ============================================================================
// Claude/Anthropic providers natively support Anthropic APIs
// ============================================================================
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
}
@@ -136,7 +197,6 @@ impl ProviderId {
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
@@ -154,7 +214,6 @@ impl ProviderId {
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
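A condensed view of the dispatch added above (illustration only; assumes the `hermesllm` crate name and public visibility of `clients::endpoints`): every client-facing API aimed at the Gemini provider resolves to the native pair below, while Gemini-format clients aimed at other providers resolve to those providers' native upstream APIs, with OpenAI chat completions as the fallback.

use hermesllm::apis::GeminiApi;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;

fn gemini_native_upstream(is_streaming: bool) -> SupportedUpstreamAPIs {
    if is_streaming {
        SupportedUpstreamAPIs::GeminiStreamGenerateContent(GeminiApi::StreamGenerateContent)
    } else {
        SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
    }
}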

View file

@@ -1,5 +1,7 @@
use crate::apis::anthropic::MessagesRequest;
use crate::apis::gemini::GenerateContentRequest;
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::ApiDefinition;
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
use crate::apis::openai_responses::ResponsesAPIRequest;
@@ -19,7 +21,8 @@ pub enum ProviderRequestType {
BedrockConverse(ConverseRequest),
BedrockConverseStream(ConverseStreamRequest),
ResponsesAPIRequest(ResponsesAPIRequest),
//add more request types here
GeminiGenerateContent(GenerateContentRequest),
GeminiStreamGenerateContent(GenerateContentRequest),
}
pub trait ProviderRequest: Send + Sync {
/// Extract the model name from the request
@@ -69,6 +72,9 @@ impl ProviderRequestType {
Self::BedrockConverse(r) => r.set_messages(messages),
Self::BedrockConverseStream(r) => r.set_messages(messages),
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
r.set_messages(messages)
}
}
}
@@ -100,6 +106,7 @@ impl ProviderRequest for ProviderRequestType {
Self::BedrockConverse(r) => r.model(),
Self::BedrockConverseStream(r) => r.model(),
Self::ResponsesAPIRequest(r) => r.model(),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.model(),
}
}
@@ -110,6 +117,9 @@
Self::BedrockConverse(r) => r.set_model(model),
Self::BedrockConverseStream(r) => r.set_model(model),
Self::ResponsesAPIRequest(r) => r.set_model(model),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
r.set_model(model)
}
}
}
@@ -120,6 +130,8 @@
Self::BedrockConverse(_) => false,
Self::BedrockConverseStream(_) => true,
Self::ResponsesAPIRequest(r) => r.is_streaming(),
Self::GeminiGenerateContent(_) => false,
Self::GeminiStreamGenerateContent(_) => true,
}
}
@@ -130,6 +142,9 @@
Self::BedrockConverse(r) => r.extract_messages_text(),
Self::BedrockConverseStream(r) => r.extract_messages_text(),
Self::ResponsesAPIRequest(r) => r.extract_messages_text(),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
r.extract_messages_text()
}
}
}
@@ -140,6 +155,9 @@
Self::BedrockConverse(r) => r.get_recent_user_message(),
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
Self::ResponsesAPIRequest(r) => r.get_recent_user_message(),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
r.get_recent_user_message()
}
}
}
@@ -150,6 +168,9 @@
Self::BedrockConverse(r) => r.get_tool_names(),
Self::BedrockConverseStream(r) => r.get_tool_names(),
Self::ResponsesAPIRequest(r) => r.get_tool_names(),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
r.get_tool_names()
}
}
}
@@ -160,6 +181,7 @@
Self::BedrockConverse(r) => r.to_bytes(),
Self::BedrockConverseStream(r) => r.to_bytes(),
Self::ResponsesAPIRequest(r) => r.to_bytes(),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.to_bytes(),
}
}
@@ -170,6 +192,7 @@
Self::BedrockConverse(r) => r.metadata(),
Self::BedrockConverseStream(r) => r.metadata(),
Self::ResponsesAPIRequest(r) => r.metadata(),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.metadata(),
}
}
@@ -180,6 +203,9 @@
Self::BedrockConverse(r) => r.remove_metadata_key(key),
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
r.remove_metadata_key(key)
}
}
}
@@ -190,6 +216,9 @@
Self::BedrockConverse(r) => r.get_temperature(),
Self::BedrockConverseStream(r) => r.get_temperature(),
Self::ResponsesAPIRequest(r) => r.get_temperature(),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
r.get_temperature()
}
}
}
@@ -200,6 +229,9 @@
Self::BedrockConverse(r) => r.get_messages(),
Self::BedrockConverseStream(r) => r.get_messages(),
Self::ResponsesAPIRequest(r) => r.get_messages(),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
r.get_messages()
}
}
}
@@ -210,6 +242,9 @@
Self::BedrockConverse(r) => r.set_messages(messages),
Self::BedrockConverseStream(r) => r.set_messages(messages),
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
r.set_messages(messages)
}
}
}
}
@@ -245,6 +280,18 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType {
responses_apirequest,
))
}
SupportedAPIsFromClient::GeminiGenerateContentAPI(gemini_api) => {
let gemini_request: GenerateContentRequest =
GenerateContentRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
if gemini_api.supports_streaming() {
Ok(ProviderRequestType::GeminiStreamGenerateContent(
gemini_request,
))
} else {
Ok(ProviderRequestType::GeminiGenerateContent(gemini_request))
}
}
}
}
}
@@ -309,6 +356,37 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType {
source: None,
})
}
// ChatCompletions -> Gemini
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::GeminiGenerateContent(_),
) => {
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to GenerateContentRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::GeminiGenerateContent(gemini_req))
}
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
) => {
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to GenerateContentRequest (stream): {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req))
}
// ============================================================================
// MessagesRequest conversions
@@ -370,6 +448,37 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType {
source: None,
})
}
// Messages -> Gemini (chain: Anthropic -> OpenAI -> Gemini)
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::GeminiGenerateContent(_),
) => {
let gemini_req = GenerateContentRequest::try_from(messages_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to GenerateContentRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::GeminiGenerateContent(gemini_req))
}
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
) => {
let gemini_req = GenerateContentRequest::try_from(messages_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to GenerateContentRequest (stream): {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req))
}
// ============================================================================
// ResponsesAPIRequest conversions (only converts TO other formats)
@@ -480,6 +589,171 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType {
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
}
// ResponsesAPI -> Gemini (via ChatCompletions)
(
ProviderRequestType::ResponsesAPIRequest(responses_req),
SupportedUpstreamAPIs::GeminiGenerateContent(_),
) => {
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to GenerateContentRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::GeminiGenerateContent(gemini_req))
}
(
ProviderRequestType::ResponsesAPIRequest(responses_req),
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
) => {
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to GenerateContentRequest (stream): {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req))
}
// ============================================================================
// GeminiGenerateContent conversions (client sends Gemini format)
// ============================================================================
(
ProviderRequestType::GeminiGenerateContent(gemini_req),
SupportedUpstreamAPIs::GeminiGenerateContent(_),
) => Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)),
(
ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
) => Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)),
// Cross-streaming mode: non-streaming -> streaming and vice versa
(
ProviderRequestType::GeminiGenerateContent(gemini_req),
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
) => Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)),
(
ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
SupportedUpstreamAPIs::GeminiGenerateContent(_),
) => Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)),
(
ProviderRequestType::GeminiGenerateContent(gemini_req)
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => {
let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
(
ProviderRequestType::GeminiGenerateContent(gemini_req)
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => {
let messages_req = MessagesRequest::try_from(gemini_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert GenerateContentRequest to MessagesRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::MessagesRequest(messages_req))
}
(
ProviderRequestType::GeminiGenerateContent(gemini_req)
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
) => {
// Chain: Gemini -> OpenAI -> Bedrock
let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
let bedrock_req = ConverseRequest::try_from(chat_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to ConverseRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(
ProviderRequestType::GeminiGenerateContent(gemini_req)
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
) => {
// Chain: Gemini -> OpenAI -> Bedrock Stream
let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
let bedrock_req = ConverseStreamRequest::try_from(chat_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to ConverseStreamRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
}
(
ProviderRequestType::GeminiGenerateContent(_)
| ProviderRequestType::GeminiStreamGenerateContent(_),
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
) => {
Err(ProviderRequestError {
message: "Conversion from GenerateContentRequest to ResponsesAPIRequest is not supported.".to_string(),
source: None,
})
}
// ============================================================================
// Amazon Bedrock conversions (not supported as client API)
// ============================================================================
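ChatCompletions is the hub format in the conversions above: a Gemini-format request headed to an OpenAI-compatible upstream is a single hop, and Bedrock targets chain through that same hop. A sketch, assuming the crate name `hermesllm` (`gemini_to_openai` is a hypothetical caller):

use hermesllm::apis::gemini::GenerateContentRequest;
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::clients::TransformError;

fn gemini_to_openai(req: GenerateContentRequest) -> Result<ChatCompletionsRequest, TransformError> {
    // Uses the TryFrom impl added in this commit (transforms/from_gemini).
    ChatCompletionsRequest::try_from(req)
}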

View file

@@ -1,5 +1,6 @@
use crate::apis::amazon_bedrock::ConverseResponse;
use crate::apis::anthropic::MessagesResponse;
use crate::apis::gemini::GenerateContentResponse;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::openai_responses::ResponsesAPIResponse;
use crate::clients::endpoints::SupportedAPIsFromClient;
@@ -16,6 +17,7 @@ pub enum ProviderResponseType {
ChatCompletionsResponse(ChatCompletionsResponse),
MessagesResponse(MessagesResponse),
ResponsesAPIResponse(Box<ResponsesAPIResponse>),
GenerateContentResponse(GenerateContentResponse),
}
/// Trait for token usage information
@@ -44,6 +46,9 @@ impl ProviderResponse for ProviderResponseType {
ProviderResponseType::ResponsesAPIResponse(resp) => {
resp.usage.as_ref().map(|u| u as &dyn TokenUsage)
}
ProviderResponseType::GenerateContentResponse(resp) => {
resp.usage_metadata.as_ref().map(|u| u as &dyn TokenUsage)
}
}
}
@@ -58,6 +63,15 @@ impl ProviderResponse for ProviderResponseType {
u.total_tokens as usize,
)
}),
ProviderResponseType::GenerateContentResponse(resp) => {
resp.usage_metadata.as_ref().map(|u| {
(
u.prompt_token_count.unwrap_or(0) as usize,
u.candidates_token_count.unwrap_or(0) as usize,
u.total_token_count.unwrap_or(0) as usize,
)
})
}
}
}
}
@@ -238,6 +252,140 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderResponseType {
response_api,
)))
}
// ============================================================================
// Gemini upstream transformations
// ============================================================================
(
SupportedUpstreamAPIs::GeminiGenerateContent(_),
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
) => {
// Passthrough: Gemini upstream -> Gemini client
let resp: GenerateContentResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::GenerateContentResponse(resp))
}
(
SupportedUpstreamAPIs::GeminiGenerateContent(_),
SupportedAPIsFromClient::OpenAIChatCompletions(_),
) => {
// Gemini upstream -> OpenAI client
let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
}
(
SupportedUpstreamAPIs::GeminiGenerateContent(_),
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
) => {
// Chain: Gemini -> OpenAI -> Anthropic
let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
let messages_resp: MessagesResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::MessagesResponse(messages_resp))
}
(
SupportedUpstreamAPIs::GeminiGenerateContent(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => {
// Chain: Gemini -> OpenAI -> ResponsesAPI
let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
let responses_resp: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
responses_resp,
)))
}
// ============================================================================
// Non-Gemini upstream -> Gemini client
// ============================================================================
(
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
) => {
// OpenAI upstream -> Gemini client
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let gemini_resp: GenerateContentResponse = openai_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
}
(
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
) => {
// Chain: Anthropic -> OpenAI -> Gemini
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let chat_resp: ChatCompletionsResponse =
anthropic_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
let gemini_resp: GenerateContentResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
}
(
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
) => {
// Chain: Bedrock -> OpenAI -> Gemini
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
let gemini_resp: GenerateContentResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
}
_ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Unsupported API combination for response transformation",

View file

@@ -83,6 +83,11 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseStreamBuffer {
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
Ok(SseStreamBuffer::OpenAIResponses(Box::default()))
}
SupportedAPIsFromClient::GeminiGenerateContentAPI(_) => {
// Gemini client with a different upstream - use passthrough
// since Gemini streaming uses SSE and doesn't need special buffering
Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new()))
}
}
}
}

View file

@@ -15,6 +15,7 @@ pub mod response_streaming;
// Re-export commonly used items for convenience
pub use lib::*;
#[allow(ambiguous_glob_reexports)]
pub use request::*;
pub use response::*;
pub use response_streaming::*;

View file

@@ -0,0 +1,327 @@
use crate::apis::gemini::GenerateContentRequest;
use crate::apis::openai::{
ChatCompletionsRequest, Function, FunctionCall as OpenAIFunctionCall, Message, MessageContent,
Role, Tool, ToolCall as OpenAIToolCall, ToolChoice, ToolChoiceType,
};
use crate::apis::anthropic::MessagesRequest;
use crate::clients::TransformError;
// ============================================================================
// Gemini GenerateContent -> OpenAI ChatCompletions
// ============================================================================
impl TryFrom<GenerateContentRequest> for ChatCompletionsRequest {
type Error = TransformError;
fn try_from(req: GenerateContentRequest) -> Result<Self, Self::Error> {
let mut messages: Vec<Message> = Vec::new();
// Convert system instruction
if let Some(system) = &req.system_instruction {
let text = system
.parts
.iter()
.filter_map(|p| p.text.clone())
.collect::<Vec<_>>()
.join("");
if !text.is_empty() {
messages.push(Message {
role: Role::System,
content: Some(MessageContent::Text(text)),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
}
// Convert contents
for content in &req.contents {
let role = match content.role.as_deref() {
Some("model") => Role::Assistant,
_ => Role::User,
};
// Check if this content has function_call parts (assistant with tool calls)
let has_function_calls = content.parts.iter().any(|p| p.function_call.is_some());
let has_function_responses =
content.parts.iter().any(|p| p.function_response.is_some());
if has_function_calls {
// Convert to assistant message with tool_calls
let mut tool_calls = Vec::new();
let mut text_parts = Vec::new();
for (i, part) in content.parts.iter().enumerate() {
if let Some(fc) = &part.function_call {
tool_calls.push(OpenAIToolCall {
id: format!("call_{}", i),
call_type: "function".to_string(),
function: OpenAIFunctionCall {
name: fc.name.clone(),
arguments: serde_json::to_string(&fc.args).unwrap_or_default(),
},
});
} else if let Some(text) = &part.text {
text_parts.push(text.clone());
}
}
let content_text = if text_parts.is_empty() {
None
} else {
Some(MessageContent::Text(text_parts.join("")))
};
messages.push(Message {
role: Role::Assistant,
content: content_text,
name: None,
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: None,
});
} else if has_function_responses {
// Convert each function_response to a tool message
for part in &content.parts {
if let Some(fr) = &part.function_response {
let result_text = serde_json::to_string(&fr.response).unwrap_or_default();
messages.push(Message {
role: Role::Tool,
content: Some(MessageContent::Text(result_text)),
name: None,
tool_calls: None,
tool_call_id: Some(fr.name.clone()),
});
}
}
} else {
// Regular text message
let text = content
.parts
.iter()
.filter_map(|p| p.text.clone())
.collect::<Vec<_>>()
.join("");
messages.push(Message {
role,
content: Some(MessageContent::Text(text)),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
}
// Convert generation config
let (temperature, top_p, max_tokens, stop, presence_penalty, frequency_penalty) =
if let Some(gc) = &req.generation_config {
(
gc.temperature,
gc.top_p,
gc.max_output_tokens,
gc.stop_sequences.clone(),
gc.presence_penalty,
gc.frequency_penalty,
)
} else {
(None, None, None, None, None, None)
};
// Convert tools
let tools = req.tools.and_then(|gemini_tools| {
let openai_tools: Vec<Tool> = gemini_tools
.iter()
.filter_map(|t| t.function_declarations.as_ref())
.flatten()
.map(|fd| Tool {
tool_type: "function".to_string(),
function: Function {
name: fd.name.clone(),
description: fd.description.clone(),
parameters: fd.parameters.clone().unwrap_or_default(),
strict: None,
},
})
.collect();
if openai_tools.is_empty() {
None
} else {
Some(openai_tools)
}
});
// Convert tool_config
let tool_choice =
req.tool_config
.and_then(|tc| match tc.function_calling_config.mode.as_str() {
"AUTO" => Some(ToolChoice::Type(ToolChoiceType::Auto)),
"NONE" => Some(ToolChoice::Type(ToolChoiceType::None)),
"ANY" => Some(ToolChoice::Type(ToolChoiceType::Required)),
_ => None,
});
Ok(ChatCompletionsRequest {
model: req.model,
messages,
temperature,
top_p,
max_completion_tokens: max_tokens,
stop,
tools,
tool_choice,
presence_penalty,
frequency_penalty,
metadata: req.metadata,
..Default::default()
})
}
}
// ============================================================================
// Gemini GenerateContent -> Anthropic Messages (via OpenAI)
// ============================================================================
impl TryFrom<GenerateContentRequest> for MessagesRequest {
type Error = TransformError;
fn try_from(req: GenerateContentRequest) -> Result<Self, Self::Error> {
// Chain: Gemini -> OpenAI -> Anthropic
let chat_req = ChatCompletionsRequest::try_from(req)?;
MessagesRequest::try_from(chat_req)
}
}
// ============================================================================
// TESTS
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::gemini::{Content, FunctionCall, Part};
use serde_json::json;
#[test]
fn test_gemini_to_openai_basic() {
let req = GenerateContentRequest {
model: "gemini-pro".to_string(),
contents: vec![
Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some("Hello".to_string()),
inline_data: None,
function_call: None,
function_response: None,
}],
},
Content {
role: Some("model".to_string()),
parts: vec![Part {
text: Some("Hi there!".to_string()),
inline_data: None,
function_call: None,
function_response: None,
}],
},
],
system_instruction: Some(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some("Be helpful".to_string()),
inline_data: None,
function_call: None,
function_response: None,
}],
}),
generation_config: Some(crate::apis::gemini::GenerationConfig {
temperature: Some(0.5),
max_output_tokens: Some(512),
..Default::default()
}),
..Default::default()
};
let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
// System + user + assistant = 3 messages
assert_eq!(openai_req.messages.len(), 3);
assert_eq!(openai_req.messages[0].role, Role::System);
assert_eq!(openai_req.messages[1].role, Role::User);
assert_eq!(openai_req.messages[2].role, Role::Assistant);
assert_eq!(openai_req.temperature, Some(0.5));
assert_eq!(openai_req.max_completion_tokens, Some(512));
}
#[test]
fn test_gemini_to_openai_with_function_calls() {
let req = GenerateContentRequest {
model: "gemini-pro".to_string(),
contents: vec![
Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some("Weather?".to_string()),
inline_data: None,
function_call: None,
function_response: None,
}],
},
Content {
role: Some("model".to_string()),
parts: vec![Part {
text: None,
inline_data: None,
function_call: Some(FunctionCall {
name: "get_weather".to_string(),
args: json!({"location": "NYC"}),
}),
function_response: None,
}],
},
],
..Default::default()
};
let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
assert_eq!(openai_req.messages.len(), 2);
assert!(openai_req.messages[1].tool_calls.is_some());
let tc = openai_req.messages[1].tool_calls.as_ref().unwrap();
assert_eq!(tc[0].function.name, "get_weather");
}
#[test]
fn test_gemini_to_openai_tool_config() {
let req = GenerateContentRequest {
model: "gemini-pro".to_string(),
contents: vec![Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some("test".to_string()),
inline_data: None,
function_call: None,
function_response: None,
}],
}],
tool_config: Some(crate::apis::gemini::ToolConfig {
function_calling_config: crate::apis::gemini::FunctionCallingConfig {
mode: "ANY".to_string(),
},
}),
..Default::default()
};
let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
assert!(openai_req.tool_choice.is_some());
assert_eq!(
openai_req.tool_choice.as_ref().unwrap(),
&ToolChoice::Type(ToolChoiceType::Required)
);
}
}
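Note the tool-choice mapping in this file: Gemini's ANY mode maps to OpenAI's Required, as the last test asserts. For the Anthropic direction, callers make one call and the OpenAI hop happens internally; a sketch under the same `hermesllm` naming assumption (`gemini_to_anthropic` is hypothetical):

use hermesllm::apis::anthropic::MessagesRequest;
use hermesllm::apis::gemini::GenerateContentRequest;
use hermesllm::clients::TransformError;

fn gemini_to_anthropic(req: GenerateContentRequest) -> Result<MessagesRequest, TransformError> {
    // Chains Gemini -> OpenAI -> Anthropic via the impls in this file.
    MessagesRequest::try_from(req)
}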

View file

@@ -1,4 +1,6 @@
//! Request transformation modules
pub mod from_anthropic;
pub mod from_gemini;
pub mod from_openai;
pub mod to_gemini;

View file

@@ -0,0 +1,323 @@
use crate::apis::gemini::{
Content, FunctionCall, FunctionCallingConfig, FunctionDeclaration, FunctionResponse,
GenerateContentRequest, GenerationConfig, Part, Tool, ToolConfig,
};
use crate::apis::openai::{ChatCompletionsRequest, Role, ToolChoice, ToolChoiceType};
use crate::apis::anthropic::MessagesRequest;
use crate::clients::TransformError;
use crate::transforms::lib::ExtractText;
// ============================================================================
// OpenAI ChatCompletions -> Gemini GenerateContent
// ============================================================================
impl TryFrom<ChatCompletionsRequest> for GenerateContentRequest {
type Error = TransformError;
fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
let mut contents: Vec<Content> = Vec::new();
let mut system_instruction: Option<Content> = None;
for msg in &req.messages {
match msg.role {
Role::System => {
let text = msg.content.extract_text();
system_instruction = Some(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
}],
});
}
Role::User => {
let text = msg.content.extract_text();
contents.push(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
}],
});
}
Role::Assistant => {
let mut parts = Vec::new();
// Check for tool calls
if let Some(tool_calls) = &msg.tool_calls {
for tc in tool_calls {
let args: serde_json::Value =
serde_json::from_str(&tc.function.arguments).unwrap_or_default();
parts.push(Part {
text: None,
inline_data: None,
function_call: Some(FunctionCall {
name: tc.function.name.clone(),
args,
}),
function_response: None,
});
}
}
// Also include text content if present
let text = msg.content.extract_text();
if !text.is_empty() {
parts.push(Part {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
});
}
if !parts.is_empty() {
contents.push(Content {
role: Some("model".to_string()),
parts,
});
}
}
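                // Gemini keys function responses by function *name*, but OpenAI
                // tool messages only carry tool_call_id, so the id stands in for
                // the name; this round-trips only if callers reuse the name as the id.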
Role::Tool => {
let text = msg.content.extract_text();
let tool_call_id = msg.tool_call_id.clone().unwrap_or_default();
let response_value = serde_json::from_str(&text)
.unwrap_or_else(|_| serde_json::json!({"result": text}));
contents.push(Content {
role: Some("user".to_string()),
parts: vec![Part {
text: None,
inline_data: None,
function_call: None,
function_response: Some(FunctionResponse {
name: tool_call_id,
response: response_value,
}),
}],
});
}
}
}
// Convert generation config
let generation_config = {
let gc = GenerationConfig {
temperature: req.temperature,
top_p: req.top_p,
top_k: None,
max_output_tokens: req.max_completion_tokens.or(req.max_tokens),
stop_sequences: req.stop,
response_mime_type: None,
candidate_count: None,
presence_penalty: req.presence_penalty,
frequency_penalty: req.frequency_penalty,
};
// Only include if any field is set
if gc.temperature.is_some()
|| gc.top_p.is_some()
|| gc.max_output_tokens.is_some()
|| gc.stop_sequences.is_some()
|| gc.presence_penalty.is_some()
|| gc.frequency_penalty.is_some()
{
Some(gc)
} else {
None
}
};
// Convert tools
let tools = req.tools.map(|openai_tools| {
let declarations: Vec<FunctionDeclaration> = openai_tools
.iter()
.map(|t| FunctionDeclaration {
name: t.function.name.clone(),
description: t.function.description.clone(),
parameters: Some(t.function.parameters.clone()),
})
.collect();
vec![Tool {
function_declarations: Some(declarations),
code_execution: None,
}]
});
        // Convert tool_choice. Gemini's allowedFunctionNames is not modeled on
        // FunctionCallingConfig here, so a specific function choice degrades to AUTO.
        let tool_config = req.tool_choice.map(|tc| {
            let mode = match tc {
                ToolChoice::Type(ToolChoiceType::Auto) => "AUTO",
                ToolChoice::Type(ToolChoiceType::None) => "NONE",
                ToolChoice::Type(ToolChoiceType::Required) => "ANY",
                ToolChoice::Function { .. } => "AUTO",
            };
            ToolConfig {
                function_calling_config: FunctionCallingConfig {
                    mode: mode.to_string(),
                },
            }
        });
Ok(GenerateContentRequest {
model: req.model,
contents,
generation_config,
tools,
tool_config,
safety_settings: None,
system_instruction,
cached_content: None,
metadata: req.metadata,
})
}
}
// ============================================================================
// Anthropic Messages -> Gemini GenerateContent (via OpenAI)
// ============================================================================
impl TryFrom<MessagesRequest> for GenerateContentRequest {
type Error = TransformError;
fn try_from(req: MessagesRequest) -> Result<Self, Self::Error> {
// Chain: Anthropic -> OpenAI -> Gemini
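        // Pivoting through ChatCompletions drops any Anthropic-only fields that
        // have no OpenAI equivalent, so the mapping can be lossy.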
let chat_req = ChatCompletionsRequest::try_from(req)?;
GenerateContentRequest::try_from(chat_req)
}
}
// ============================================================================
// TESTS
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_openai_to_gemini_basic() {
let req: ChatCompletionsRequest = serde_json::from_value(json!({
"model": "gemini-pro",
"messages": [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
],
"temperature": 0.7,
"max_tokens": 1024
}))
.unwrap();
let gemini_req = GenerateContentRequest::try_from(req).unwrap();
// System should be in system_instruction
assert!(gemini_req.system_instruction.is_some());
let sys = gemini_req.system_instruction.as_ref().unwrap();
assert_eq!(sys.parts[0].text.as_deref(), Some("You are helpful"));
// 3 content messages (user, model, user)
assert_eq!(gemini_req.contents.len(), 3);
assert_eq!(gemini_req.contents[0].role.as_deref(), Some("user"));
assert_eq!(gemini_req.contents[1].role.as_deref(), Some("model"));
assert_eq!(gemini_req.contents[2].role.as_deref(), Some("user"));
// Generation config
assert_eq!(
gemini_req.generation_config.as_ref().unwrap().temperature,
Some(0.7)
);
assert_eq!(
gemini_req
.generation_config
.as_ref()
.unwrap()
.max_output_tokens,
Some(1024)
);
}
#[test]
fn test_openai_to_gemini_with_tools() {
let req: ChatCompletionsRequest = serde_json::from_value(json!({
"model": "gemini-pro",
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"location": {"type": "string"}}}
}
}],
"tool_choice": "auto"
}))
.unwrap();
let gemini_req = GenerateContentRequest::try_from(req).unwrap();
assert!(gemini_req.tools.is_some());
let tools = gemini_req.tools.as_ref().unwrap();
assert_eq!(tools.len(), 1);
let decls = tools[0].function_declarations.as_ref().unwrap();
assert_eq!(decls[0].name, "get_weather");
assert!(gemini_req.tool_config.is_some());
assert_eq!(
gemini_req
.tool_config
.as_ref()
.unwrap()
.function_calling_config
.mode,
"AUTO"
);
}
#[test]
fn test_openai_to_gemini_with_tool_calls() {
let req: ChatCompletionsRequest = serde_json::from_value(json!({
"model": "gemini-pro",
"messages": [
{"role": "user", "content": "What's the weather?"},
{
"role": "assistant",
"tool_calls": [{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"location\": \"NYC\"}"
}
}]
},
{
"role": "tool",
"tool_call_id": "call_123",
"content": "Sunny, 72F"
}
]
}))
.unwrap();
let gemini_req = GenerateContentRequest::try_from(req).unwrap();
assert_eq!(gemini_req.contents.len(), 3);
// Assistant with function_call
let model_content = &gemini_req.contents[1];
assert_eq!(model_content.role.as_deref(), Some("model"));
assert!(model_content.parts[0].function_call.is_some());
// Tool response
let tool_content = &gemini_req.contents[2];
assert_eq!(tool_content.role.as_deref(), Some("user"));
assert!(tool_content.parts[0].function_response.is_some());
}
}

View file

@@ -0,0 +1,417 @@
use crate::apis::anthropic::MessagesResponse;
use crate::apis::gemini::GenerateContentResponse;
use crate::apis::openai::{
ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason,
FunctionCall as OpenAIFunctionCall, MessageDelta, ResponseMessage, Role, StreamChoice,
ToolCall as OpenAIToolCall, Usage,
};
use crate::clients::TransformError;
// ============================================================================
// Gemini GenerateContentResponse -> OpenAI ChatCompletionsResponse
// ============================================================================
fn map_finish_reason(gemini_reason: Option<&str>) -> Option<FinishReason> {
gemini_reason.map(|r| match r {
"STOP" => FinishReason::Stop,
"MAX_TOKENS" => FinishReason::Length,
"SAFETY" | "RECITATION" => FinishReason::ContentFilter,
_ => FinishReason::Stop,
})
}
impl TryFrom<GenerateContentResponse> for ChatCompletionsResponse {
type Error = TransformError;
fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
let candidates = resp.candidates.unwrap_or_default();
let candidate = candidates.first();
let mut content_text = String::new();
let mut tool_calls: Vec<OpenAIToolCall> = Vec::new();
if let Some(candidate) = candidate {
if let Some(ref content) = candidate.content {
for (i, part) in content.parts.iter().enumerate() {
if let Some(ref text) = part.text {
content_text.push_str(text);
}
if let Some(ref fc) = part.function_call {
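                        // Gemini does not return call ids; synthesize one per part
                        // so OpenAI-style clients can still correlate tool results.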
tool_calls.push(OpenAIToolCall {
id: format!("call_{}", i),
call_type: "function".to_string(),
function: OpenAIFunctionCall {
name: fc.name.clone(),
arguments: serde_json::to_string(&fc.args).unwrap_or_default(),
},
});
}
}
}
}
let finish_reason = candidate
.and_then(|c| map_finish_reason(c.finish_reason.as_deref()))
.unwrap_or(FinishReason::Stop);
let message_content = if content_text.is_empty() {
None
} else {
Some(content_text)
};
let tool_calls_opt = if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
};
let choice = Choice {
index: 0,
message: ResponseMessage {
role: Role::Assistant,
content: message_content,
tool_calls: tool_calls_opt,
refusal: None,
annotations: None,
audio: None,
function_call: None,
},
finish_reason: Some(finish_reason),
logprobs: None,
};
let usage = resp
.usage_metadata
.map(|um| Usage {
prompt_tokens: um.prompt_token_count.unwrap_or(0),
completion_tokens: um.candidates_token_count.unwrap_or(0),
total_tokens: um.total_token_count.unwrap_or(0),
prompt_tokens_details: None,
completion_tokens_details: None,
})
.unwrap_or_default();
Ok(ChatCompletionsResponse {
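            // Gemini responses carry no id, so synthesize one from the wall clock.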
id: format!(
"chatcmpl-gemini-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis()
),
object: Some("chat.completion".to_string()),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: resp.model_version.unwrap_or_else(|| "gemini".to_string()),
choices: vec![choice],
usage,
system_fingerprint: None,
service_tier: None,
metadata: None,
})
}
}
// ============================================================================
// Gemini GenerateContentResponse -> Anthropic MessagesResponse (via OpenAI)
// ============================================================================
impl TryFrom<GenerateContentResponse> for MessagesResponse {
type Error = TransformError;
fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
// Chain: Gemini -> OpenAI -> Anthropic
let chat_resp = ChatCompletionsResponse::try_from(resp)?;
MessagesResponse::try_from(chat_resp)
}
}
// ============================================================================
// Gemini GenerateContentResponse -> OpenAI ChatCompletionsStreamResponse
// ============================================================================
impl TryFrom<GenerateContentResponse> for ChatCompletionsStreamResponse {
type Error = TransformError;
fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
let candidates = resp.candidates.unwrap_or_default();
let candidate = candidates.first();
let mut delta_content: Option<String> = None;
if let Some(candidate) = candidate {
if let Some(ref content) = candidate.content {
                let text: String = content
                    .parts
                    .iter()
                    .filter_map(|p| p.text.as_deref())
                    .collect();
                if !text.is_empty() {
                    delta_content = Some(text);
                }
}
}
let finish_reason = candidate.and_then(|c| map_finish_reason(c.finish_reason.as_deref()));
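        // Gemini labels candidate content "model"; anything else is unexpected
        // and is surfaced as User rather than silently dropped.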
let role = candidate
.and_then(|c| c.content.as_ref())
.and_then(|c| c.role.as_deref())
.map(|r| match r {
"model" => Role::Assistant,
_ => Role::User,
});
Ok(ChatCompletionsStreamResponse {
id: format!(
"chatcmpl-gemini-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis()
),
object: Some("chat.completion.chunk".to_string()),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: resp.model_version.unwrap_or_else(|| "gemini".to_string()),
choices: vec![StreamChoice {
index: 0,
delta: MessageDelta {
role,
content: delta_content,
tool_calls: None,
refusal: None,
function_call: None,
},
finish_reason,
logprobs: None,
}],
usage: None,
system_fingerprint: None,
service_tier: None,
})
}
}
// ============================================================================
// REVERSE: OpenAI ChatCompletionsResponse -> Gemini GenerateContentResponse
// ============================================================================
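// Mirrors the forward transform: choice text and tool_calls become Gemini text
// and functionCall parts under a single "model" candidate.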
impl TryFrom<ChatCompletionsResponse> for GenerateContentResponse {
type Error = TransformError;
fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
use crate::apis::gemini::{Candidate, Content, FunctionCall, Part, UsageMetadata};
let candidates = if let Some(choice) = resp.choices.first() {
let mut parts = Vec::new();
// Text content
if let Some(ref content) = choice.message.content {
if !content.is_empty() {
parts.push(Part {
text: Some(content.clone()),
inline_data: None,
function_call: None,
function_response: None,
});
}
}
// Tool calls
if let Some(ref tool_calls) = choice.message.tool_calls {
for tc in tool_calls {
let args: serde_json::Value =
serde_json::from_str(&tc.function.arguments).unwrap_or_default();
parts.push(Part {
text: None,
inline_data: None,
function_call: Some(FunctionCall {
name: tc.function.name.clone(),
args,
}),
function_response: None,
});
}
}
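            // Keep parts non-empty (Gemini contents always carry at least one
            // part) by emitting an empty text part as a placeholder.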
if parts.is_empty() {
parts.push(Part {
text: Some(String::new()),
inline_data: None,
function_call: None,
function_response: None,
});
}
            let finish_reason = choice.finish_reason.as_ref().map(|fr| match fr {
                FinishReason::Stop | FinishReason::ToolCalls | FinishReason::FunctionCall => {
                    "STOP".to_string()
                }
                FinishReason::Length => "MAX_TOKENS".to_string(),
                FinishReason::ContentFilter => "SAFETY".to_string(),
            });
vec![Candidate {
content: Some(Content {
role: Some("model".to_string()),
parts,
}),
finish_reason,
safety_ratings: None,
}]
} else {
vec![]
};
let usage_metadata = Some(UsageMetadata {
prompt_token_count: Some(resp.usage.prompt_tokens),
candidates_token_count: Some(resp.usage.completion_tokens),
total_token_count: Some(resp.usage.total_tokens),
});
Ok(GenerateContentResponse {
candidates: Some(candidates),
usage_metadata,
prompt_feedback: None,
model_version: Some(resp.model),
})
}
}
// ============================================================================
// TESTS
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_gemini_to_openai_response() {
let resp: GenerateContentResponse = serde_json::from_value(json!({
"candidates": [{
"content": {
"role": "model",
"parts": [{"text": "Hello! How can I help?"}]
},
"finishReason": "STOP"
}],
"usageMetadata": {
"promptTokenCount": 5,
"candidatesTokenCount": 7,
"totalTokenCount": 12
},
"modelVersion": "gemini-2.0-flash"
}))
.unwrap();
let openai_resp = ChatCompletionsResponse::try_from(resp).unwrap();
assert_eq!(openai_resp.choices.len(), 1);
let msg = &openai_resp.choices[0].message;
assert_eq!(msg.content.as_deref(), Some("Hello! How can I help?"));
assert_eq!(
openai_resp.choices[0].finish_reason,
Some(FinishReason::Stop)
);
assert_eq!(openai_resp.usage.prompt_tokens, 5);
assert_eq!(openai_resp.usage.completion_tokens, 7);
}
#[test]
fn test_gemini_to_openai_stream_response() {
let resp: GenerateContentResponse = serde_json::from_value(json!({
"candidates": [{
"content": {
"role": "model",
"parts": [{"text": "Hello"}]
}
}]
}))
.unwrap();
let stream_resp = ChatCompletionsStreamResponse::try_from(resp).unwrap();
assert_eq!(stream_resp.choices.len(), 1);
assert_eq!(
stream_resp.choices[0].delta.content,
Some("Hello".to_string())
);
assert_eq!(stream_resp.choices[0].delta.role, Some(Role::Assistant));
}
#[test]
fn test_gemini_to_openai_with_function_call() {
let resp: GenerateContentResponse = serde_json::from_value(json!({
"candidates": [{
"content": {
"role": "model",
"parts": [{
"functionCall": {
"name": "get_weather",
"args": {"location": "NYC"}
}
}]
},
"finishReason": "STOP"
}]
}))
.unwrap();
let openai_resp = ChatCompletionsResponse::try_from(resp).unwrap();
let msg = &openai_resp.choices[0].message;
assert!(msg.tool_calls.is_some());
let tc = msg.tool_calls.as_ref().unwrap();
assert_eq!(tc[0].function.name, "get_weather");
}
#[test]
fn test_openai_to_gemini_response() {
let resp: ChatCompletionsResponse = serde_json::from_value(json!({
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-4",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": "Hello!"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}
}))
.unwrap();
let gemini_resp = GenerateContentResponse::try_from(resp).unwrap();
let candidates = gemini_resp.candidates.as_ref().unwrap();
assert_eq!(candidates.len(), 1);
let parts = &candidates[0].content.as_ref().unwrap().parts;
assert_eq!(parts[0].text.as_deref(), Some("Hello!"));
assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP"));
}
#[test]
fn test_finish_reason_mapping() {
assert_eq!(map_finish_reason(Some("STOP")), Some(FinishReason::Stop));
assert_eq!(
map_finish_reason(Some("MAX_TOKENS")),
Some(FinishReason::Length)
);
assert_eq!(
map_finish_reason(Some("SAFETY")),
Some(FinishReason::ContentFilter)
);
assert_eq!(
map_finish_reason(Some("RECITATION")),
Some(FinishReason::ContentFilter)
);
assert_eq!(map_finish_reason(None), None);
}
}

View file

@@ -1,4 +1,5 @@
//! Response transformation modules
pub mod from_gemini;
pub mod output_to_input;
pub mod to_anthropic;
pub mod to_openai;

View file

@@ -217,7 +217,9 @@ impl StreamContext {
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
| SupportedUpstreamAPIs::AmazonBedrockConverse(_)
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_)
| SupportedUpstreamAPIs::GeminiGenerateContent(_)
| SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
)
| None => {
// OpenAI and default: use Authorization Bearer token