mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
add native Gemini provider support via hermesllm transforms
This commit is contained in:
parent
5400b0a2fa
commit
053108b96c
16 changed files with 2416 additions and 10 deletions
|
|
@ -53,7 +53,9 @@ pub async fn router_chat_get_upstream_model(
|
|||
ProviderRequestType::MessagesRequest(_)
|
||||
| ProviderRequestType::BedrockConverse(_)
|
||||
| ProviderRequestType::BedrockConverseStream(_)
|
||||
| ProviderRequestType::ResponsesAPIRequest(_),
|
||||
| ProviderRequestType::ResponsesAPIRequest(_)
|
||||
| ProviderRequestType::GeminiGenerateContent(_)
|
||||
| ProviderRequestType::GeminiStreamGenerateContent(_),
|
||||
) => {
|
||||
warn!("unexpected: got non-ChatCompletions request after converting to OpenAI format");
|
||||
return Err(RoutingError::internal_error(
|
||||
|
|
|
|||
744
crates/hermesllm/src/apis/gemini.rs
Normal file
744
crates/hermesllm/src/apis/gemini.rs
Normal file
|
|
@ -0,0 +1,744 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::TokenUsage;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::GENERATE_CONTENT_PATH_SUFFIX;
|
||||
|
||||
// ============================================================================
|
||||
// GEMINI GENERATE CONTENT API ENUMERATION
|
||||
// ============================================================================
|
||||
|
||||
/// Enum for all supported Gemini GenerateContent APIs
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum GeminiApi {
|
||||
GenerateContent,
|
||||
StreamGenerateContent,
|
||||
}
|
||||
|
||||
impl ApiDefinition for GeminiApi {
|
||||
fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
GeminiApi::GenerateContent => ":generateContent",
|
||||
GeminiApi::StreamGenerateContent => ":streamGenerateContent",
|
||||
}
|
||||
}
|
||||
|
||||
fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
if endpoint.ends_with(":streamGenerateContent") {
|
||||
Some(GeminiApi::StreamGenerateContent)
|
||||
} else if endpoint.ends_with(GENERATE_CONTENT_PATH_SUFFIX) {
|
||||
Some(GeminiApi::GenerateContent)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_streaming(&self) -> bool {
|
||||
match self {
|
||||
GeminiApi::GenerateContent => false,
|
||||
GeminiApi::StreamGenerateContent => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_vision(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![GeminiApi::GenerateContent, GeminiApi::StreamGenerateContent]
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// REQUEST TYPES
|
||||
// ============================================================================
|
||||
|
||||
/// Request body for Gemini `generateContent` / `streamGenerateContent`.
///
/// Mirrors the Gemini REST wire format (camelCase field names); the fields
/// marked `skip_serializing` below are internal-only and never sent upstream.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Internal model field — not part of Gemini wire format (model is in the URL).
    /// Populated during parsing and used for routing.
    #[serde(skip_serializing, default)]
    pub model: String,

    /// Conversation turns, each a role plus a list of parts.
    pub contents: Vec<Content>,
    /// Sampling / generation parameters (temperature, topP, ...).
    pub generation_config: Option<GenerationConfig>,
    /// Tool (function) declarations available to the model.
    pub tools: Option<Vec<Tool>>,
    /// Controls how/when the model may call declared functions.
    pub tool_config: Option<ToolConfig>,
    /// Per-category safety thresholds.
    pub safety_settings: Option<Vec<SafetySetting>>,
    /// System prompt; serialized as `systemInstruction` on the wire.
    pub system_instruction: Option<Content>,
    /// Name of a cached-content resource to reuse.
    pub cached_content: Option<String>,

    /// Internal routing metadata — accepted on input but never serialized
    /// back to the provider.
    #[serde(skip_serializing)]
    pub metadata: Option<HashMap<String, Value>>,
}
|
||||
|
||||
/// One conversation turn: an optional role ("user" / "model") plus parts.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    pub role: Option<String>,
    pub parts: Vec<Part>,
}

/// A single piece of a turn. Normally exactly one field is set: text,
/// inline binary data, a model-issued function call, or a function result.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Part {
    pub text: Option<String>,
    pub inline_data: Option<InlineData>,
    pub function_call: Option<FunctionCall>,
    pub function_response: Option<FunctionResponse>,
}

/// Inline binary payload (e.g. an image) together with its MIME type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineData {
    pub mime_type: String,
    pub data: String,
}

/// A model-issued function call: function name plus JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCall {
    pub name: String,
    pub args: Value,
}

/// A client-supplied function result echoed back to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponse {
    pub name: String,
    pub response: Value,
}
|
||||
|
||||
/// Generation / sampling parameters; all optional, serialized camelCase.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub max_output_tokens: Option<u32>,
    pub stop_sequences: Option<Vec<String>>,
    pub response_mime_type: Option<String>,
    pub candidate_count: Option<u32>,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
}

/// A tool made available to the model: function declarations and/or
/// provider-side code execution.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    pub function_declarations: Option<Vec<FunctionDeclaration>>,
    pub code_execution: Option<Value>,
}

/// Declaration of a callable function: name, description, and parameters
/// (a JSON-schema-shaped `Value`).
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

/// Wrapper for function-calling behavior configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    pub function_calling_config: FunctionCallingConfig,
}

/// Function-calling mode string (e.g. "AUTO", as exercised in the tests).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
    pub mode: String,
}

/// A per-category safety threshold setting.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: String,
    pub threshold: String,
}
|
||||
|
||||
// ============================================================================
|
||||
// RESPONSE TYPES
|
||||
// ============================================================================
|
||||
|
||||
/// Response body for `generateContent`; also the shape of each streamed
/// chunk from `streamGenerateContent` (see the `ProviderStreamResponse`
/// impl below).
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    pub candidates: Option<Vec<Candidate>>,
    pub usage_metadata: Option<UsageMetadata>,
    pub prompt_feedback: Option<PromptFeedback>,
    pub model_version: Option<String>,
}

/// One generated candidate: content, why generation stopped, safety ratings.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
    pub content: Option<Content>,
    pub finish_reason: Option<String>,
    pub safety_ratings: Option<Vec<SafetyRating>>,
}

/// Token accounting reported by Gemini.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    pub prompt_token_count: Option<u32>,
    pub candidates_token_count: Option<u32>,
    pub total_token_count: Option<u32>,
}
|
||||
|
||||
/// Map Gemini usage counters onto the provider-agnostic `TokenUsage`
/// trait. Counters that Gemini did not report are treated as zero.
impl TokenUsage for UsageMetadata {
    fn completion_tokens(&self) -> usize {
        self.candidates_token_count.map_or(0, |n| n as usize)
    }

    fn prompt_tokens(&self) -> usize {
        self.prompt_token_count.map_or(0, |n| n as usize)
    }

    fn total_tokens(&self) -> usize {
        self.total_token_count.map_or(0, |n| n as usize)
    }
}
|
||||
|
||||
/// Feedback about the prompt itself (e.g. the reason it was blocked).
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    pub block_reason: Option<String>,
    pub safety_ratings: Option<Vec<SafetyRating>>,
}

/// A single safety-category rating on a prompt or candidate.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: String,
    pub probability: String,
    pub blocked: Option<bool>,
}
|
||||
|
||||
// ============================================================================
|
||||
// PROVIDER REQUEST TRAIT IMPLEMENTATION
|
||||
// ============================================================================
|
||||
|
||||
impl ProviderRequest for GenerateContentRequest {
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: String) {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
// Gemini uses URL-based streaming, not a field in the request body
|
||||
false
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
let mut parts_text = Vec::new();
|
||||
for content in &self.contents {
|
||||
for part in &content.parts {
|
||||
if let Some(text) = &part.text {
|
||||
parts_text.push(text.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(system) = &self.system_instruction {
|
||||
for part in &system.parts {
|
||||
if let Some(text) = &part.text {
|
||||
parts_text.push(text.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
parts_text.join(" ")
|
||||
}
|
||||
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
self.contents
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|c| c.role.as_deref() == Some("user"))
|
||||
.and_then(|c| {
|
||||
c.parts
|
||||
.iter()
|
||||
.filter_map(|p| p.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.first()
|
||||
.cloned()
|
||||
})
|
||||
}
|
||||
|
||||
fn get_tool_names(&self) -> Option<Vec<String>> {
|
||||
self.tools.as_ref().map(|tools| {
|
||||
tools
|
||||
.iter()
|
||||
.filter_map(|t| t.function_declarations.as_ref())
|
||||
.flatten()
|
||||
.map(|f| f.name.clone())
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
serde_json::to_vec(self).map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to serialize GenerateContentRequest: {}", e),
|
||||
source: Some(Box::new(e)),
|
||||
})
|
||||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
&self.metadata
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
if let Some(ref mut metadata) = self.metadata {
|
||||
metadata.remove(key).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn get_temperature(&self) -> Option<f32> {
|
||||
self.generation_config
|
||||
.as_ref()
|
||||
.and_then(|gc| gc.temperature)
|
||||
}
|
||||
|
||||
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
|
||||
use crate::apis::openai::{Message, MessageContent, Role};
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
// Convert system instruction
|
||||
if let Some(system) = &self.system_instruction {
|
||||
let text = system
|
||||
.parts
|
||||
.iter()
|
||||
.filter_map(|p| p.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
if !text.is_empty() {
|
||||
messages.push(Message {
|
||||
role: Role::System,
|
||||
content: Some(MessageContent::Text(text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert contents
|
||||
for content in &self.contents {
|
||||
let role = match content.role.as_deref() {
|
||||
Some("model") => Role::Assistant,
|
||||
_ => Role::User,
|
||||
};
|
||||
|
||||
let text = content
|
||||
.parts
|
||||
.iter()
|
||||
.filter_map(|p| p.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
|
||||
messages.push(Message {
|
||||
role,
|
||||
content: Some(MessageContent::Text(text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
|
||||
use crate::apis::openai::Role;
|
||||
|
||||
self.contents.clear();
|
||||
self.system_instruction = None;
|
||||
|
||||
for msg in messages {
|
||||
let text = msg.content.extract_text();
|
||||
match msg.role {
|
||||
Role::System => {
|
||||
self.system_instruction = Some(Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
}],
|
||||
});
|
||||
}
|
||||
Role::User => {
|
||||
self.contents.push(Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
}],
|
||||
});
|
||||
}
|
||||
Role::Assistant => {
|
||||
self.contents.push(Content {
|
||||
role: Some("model".to_string()),
|
||||
parts: vec![Part {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
}],
|
||||
});
|
||||
}
|
||||
Role::Tool => {
|
||||
self.contents.push(Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// PROVIDER STREAM RESPONSE TRAIT IMPLEMENTATION
|
||||
// ============================================================================
|
||||
|
||||
impl ProviderStreamResponse for GenerateContentResponse {
    /// Text delta: first candidate's first part's text, if any.
    fn content_delta(&self) -> Option<&str> {
        self.candidates
            .as_ref()
            .and_then(|candidates| candidates.first())
            .and_then(|candidate| candidate.content.as_ref())
            .and_then(|content| content.parts.first())
            .and_then(|part| part.text.as_deref())
    }

    /// A chunk is final when the first candidate carries a terminal
    /// `finishReason`.
    ///
    /// The previous allow-list (`STOP` / `MAX_TOKENS` / `SAFETY`) missed
    /// other documented terminal reasons (`RECITATION`, `OTHER`,
    /// `BLOCKLIST`, ...), so streams ending with those were never marked
    /// final. Any reason other than the `FINISH_REASON_UNSPECIFIED`
    /// placeholder now counts as final.
    fn is_final(&self) -> bool {
        self.candidates
            .as_ref()
            .and_then(|candidates| candidates.first())
            .and_then(|candidate| candidate.finish_reason.as_deref())
            .map(|reason| !reason.is_empty() && reason != "FINISH_REASON_UNSPECIFIED")
            .unwrap_or(false)
    }

    /// Role of the first candidate's content (normally "model").
    fn role(&self) -> Option<&str> {
        self.candidates
            .as_ref()
            .and_then(|candidates| candidates.first())
            .and_then(|candidate| candidate.content.as_ref())
            .and_then(|content| content.role.as_deref())
    }

    fn event_type(&self) -> Option<&str> {
        None // Gemini doesn't use SSE event types
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// SERDE PARSING
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<&[u8]> for GenerateContentRequest {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for GenerateContentResponse {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Endpoint suffix matching: streaming vs unary vs unrelated paths.
    #[test]
    fn test_gemini_api_from_endpoint() {
        assert_eq!(
            GeminiApi::from_endpoint("/v1beta/models/gemini-pro:generateContent"),
            Some(GeminiApi::GenerateContent)
        );
        assert_eq!(
            GeminiApi::from_endpoint("/v1beta/models/gemini-pro:streamGenerateContent"),
            Some(GeminiApi::StreamGenerateContent)
        );
        assert_eq!(GeminiApi::from_endpoint("/v1/chat/completions"), None);
    }

    // Request deserialization from camelCase wire JSON, plus a serialize
    // round-trip through bytes.
    #[test]
    fn test_generate_content_request_serde() {
        let json_str = json!({
            "contents": [{
                "role": "user",
                "parts": [{"text": "Hello"}]
            }],
            "generationConfig": {
                "temperature": 0.7,
                "maxOutputTokens": 1024
            }
        });

        let req: GenerateContentRequest = serde_json::from_value(json_str).unwrap();
        assert_eq!(req.contents.len(), 1);
        assert_eq!(req.contents[0].role, Some("user".to_string()));
        assert_eq!(
            req.generation_config.as_ref().unwrap().temperature,
            Some(0.7)
        );
        assert_eq!(
            req.generation_config.as_ref().unwrap().max_output_tokens,
            Some(1024)
        );

        // Roundtrip
        let bytes = serde_json::to_vec(&req).unwrap();
        let req2: GenerateContentRequest = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(req2.contents.len(), 1);
    }

    // Response deserialization: candidates, finishReason, usage metadata.
    #[test]
    fn test_generate_content_response_serde() {
        let json_str = json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{"text": "Hello! How can I help?"}]
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 7,
                "totalTokenCount": 12
            }
        });

        let resp: GenerateContentResponse = serde_json::from_value(json_str).unwrap();
        assert!(resp.candidates.is_some());
        let candidates = resp.candidates.as_ref().unwrap();
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP"));
        assert_eq!(
            resp.usage_metadata.as_ref().unwrap().prompt_token_count,
            Some(5)
        );
    }

    // Tool declarations and toolConfig survive deserialization.
    #[test]
    fn test_generate_content_request_with_tools() {
        let json_str = json!({
            "contents": [{
                "role": "user",
                "parts": [{"text": "What's the weather?"}]
            }],
            "tools": [{
                "functionDeclarations": [{
                    "name": "get_weather",
                    "description": "Get weather info",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string"}
                        }
                    }
                }]
            }],
            "toolConfig": {
                "functionCallingConfig": {
                    "mode": "AUTO"
                }
            }
        });

        let req: GenerateContentRequest = serde_json::from_value(json_str).unwrap();
        assert!(req.tools.is_some());
        let tools = req.tools.as_ref().unwrap();
        assert_eq!(tools.len(), 1);
        let decls = tools[0].function_declarations.as_ref().unwrap();
        assert_eq!(decls[0].name, "get_weather");
        assert_eq!(
            req.tool_config
                .as_ref()
                .unwrap()
                .function_calling_config
                .mode,
            "AUTO"
        );
    }

    // A functionCall part in the response parses into Part::function_call.
    #[test]
    fn test_generate_content_response_with_function_call() {
        let json_str = json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"location": "NYC"}
                        }
                    }]
                },
                "finishReason": "STOP"
            }]
        });

        let resp: GenerateContentResponse = serde_json::from_value(json_str).unwrap();
        let candidates = resp.candidates.as_ref().unwrap();
        let parts = &candidates[0].content.as_ref().unwrap().parts;
        assert!(parts[0].function_call.is_some());
        assert_eq!(parts[0].function_call.as_ref().unwrap().name, "get_weather");
    }

    // content_delta() surfaces the first part's text; no finishReason means
    // the chunk is not final.
    #[test]
    fn test_stream_response_content_delta() {
        let resp = GenerateContentResponse {
            candidates: Some(vec![Candidate {
                content: Some(Content {
                    role: Some("model".to_string()),
                    parts: vec![Part {
                        text: Some("Hello".to_string()),
                        inline_data: None,
                        function_call: None,
                        function_response: None,
                    }],
                }),
                finish_reason: None,
                safety_ratings: None,
            }]),
            usage_metadata: None,
            prompt_feedback: None,
            model_version: None,
        };

        assert_eq!(resp.content_delta(), Some("Hello"));
        assert!(!resp.is_final());
    }

    // A "STOP" finishReason marks the chunk as final.
    #[test]
    fn test_stream_response_is_final() {
        let resp = GenerateContentResponse {
            candidates: Some(vec![Candidate {
                content: Some(Content {
                    role: Some("model".to_string()),
                    parts: vec![Part {
                        text: Some("Done".to_string()),
                        inline_data: None,
                        function_call: None,
                        function_response: None,
                    }],
                }),
                finish_reason: Some("STOP".to_string()),
                safety_ratings: None,
            }]),
            usage_metadata: None,
            prompt_feedback: None,
            model_version: None,
        };

        assert!(resp.is_final());
    }

    // extract_messages_text() aggregates both contents and systemInstruction.
    #[test]
    fn test_provider_request_extract_text() {
        let req = GenerateContentRequest {
            model: "gemini-pro".to_string(),
            contents: vec![Content {
                role: Some("user".to_string()),
                parts: vec![Part {
                    text: Some("Hello world".to_string()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }],
            }],
            system_instruction: Some(Content {
                role: Some("user".to_string()),
                parts: vec![Part {
                    text: Some("Be helpful".to_string()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }],
            }),
            ..Default::default()
        };

        let text = req.extract_messages_text();
        assert!(text.contains("Hello world"));
        assert!(text.contains("Be helpful"));
    }

    // get_tool_names() flattens declarations across all tools, in order.
    #[test]
    fn test_provider_request_get_tool_names() {
        let req = GenerateContentRequest {
            model: "gemini-pro".to_string(),
            contents: vec![],
            tools: Some(vec![Tool {
                function_declarations: Some(vec![
                    FunctionDeclaration {
                        name: "func_a".to_string(),
                        description: None,
                        parameters: None,
                    },
                    FunctionDeclaration {
                        name: "func_b".to_string(),
                        description: None,
                        parameters: None,
                    },
                ]),
                code_execution: None,
            }]),
            ..Default::default()
        };

        let names = req.get_tool_names().unwrap();
        assert_eq!(names, vec!["func_a", "func_b"]);
    }
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod amazon_bedrock;
|
||||
pub mod anthropic;
|
||||
pub mod gemini;
|
||||
pub mod openai;
|
||||
pub mod openai_responses;
|
||||
pub mod streaming_shapes;
|
||||
|
|
@ -10,6 +11,7 @@ pub use amazon_bedrock::{
|
|||
Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
|
||||
};
|
||||
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
|
||||
pub use gemini::{GeminiApi, GenerateContentRequest, GenerateContentResponse};
|
||||
pub use openai::{
|
||||
ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
|
||||
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, GeminiApi, OpenAIApi};
|
||||
use crate::ProviderId;
|
||||
use std::fmt;
|
||||
|
||||
|
|
@ -8,6 +8,7 @@ pub enum SupportedAPIsFromClient {
|
|||
OpenAIChatCompletions(OpenAIApi),
|
||||
AnthropicMessagesAPI(AnthropicApi),
|
||||
OpenAIResponsesAPI(OpenAIApi),
|
||||
GeminiGenerateContentAPI(GeminiApi),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
|
@ -17,6 +18,8 @@ pub enum SupportedUpstreamAPIs {
|
|||
AmazonBedrockConverse(AmazonBedrockApi),
|
||||
AmazonBedrockConverseStream(AmazonBedrockApi),
|
||||
OpenAIResponsesAPI(OpenAIApi),
|
||||
GeminiGenerateContent(GeminiApi),
|
||||
GeminiStreamGenerateContent(GeminiApi),
|
||||
}
|
||||
|
||||
impl fmt::Display for SupportedAPIsFromClient {
|
||||
|
|
@ -31,6 +34,9 @@ impl fmt::Display for SupportedAPIsFromClient {
|
|||
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => {
|
||||
write!(f, "OpenAI Responses ({})", api.endpoint())
|
||||
}
|
||||
SupportedAPIsFromClient::GeminiGenerateContentAPI(api) => {
|
||||
write!(f, "Gemini ({})", api.endpoint())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -53,6 +59,12 @@ impl fmt::Display for SupportedUpstreamAPIs {
|
|||
SupportedUpstreamAPIs::OpenAIResponsesAPI(api) => {
|
||||
write!(f, "OpenAI Responses ({})", api.endpoint())
|
||||
}
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(api) => {
|
||||
write!(f, "Gemini ({})", api.endpoint())
|
||||
}
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(api) => {
|
||||
write!(f, "Gemini Stream ({})", api.endpoint())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -60,6 +72,13 @@ impl fmt::Display for SupportedUpstreamAPIs {
|
|||
impl SupportedAPIsFromClient {
|
||||
/// Create a SupportedApi from an endpoint path
|
||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
// Check Gemini first since it uses suffix matching (`:generateContent`)
|
||||
if let Some(gemini_api) = GeminiApi::from_endpoint(endpoint) {
|
||||
return Some(SupportedAPIsFromClient::GeminiGenerateContentAPI(
|
||||
gemini_api,
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||
// Check if this is the Responses API endpoint
|
||||
if openai_api == OpenAIApi::Responses {
|
||||
|
|
@ -82,6 +101,7 @@ impl SupportedAPIsFromClient {
|
|||
SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(),
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(),
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(),
|
||||
SupportedAPIsFromClient::GeminiGenerateContentAPI(api) => api.endpoint(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -145,7 +165,18 @@ impl SupportedAPIsFromClient {
|
|||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/v1beta/openai", endpoint_suffix)
|
||||
// Use native Gemini endpoint
|
||||
if !is_streaming {
|
||||
build_endpoint(
|
||||
"/v1beta",
|
||||
&format!("/models/{}:generateContent", model_id),
|
||||
)
|
||||
} else {
|
||||
build_endpoint(
|
||||
"/v1beta",
|
||||
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
|
|
@ -178,6 +209,20 @@ impl SupportedAPIsFromClient {
|
|||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
// Translate Anthropic → Gemini native
|
||||
if !is_streaming {
|
||||
build_endpoint(
|
||||
"/v1beta",
|
||||
&format!("/models/{}:generateContent", model_id),
|
||||
)
|
||||
} else {
|
||||
build_endpoint(
|
||||
"/v1beta",
|
||||
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => build_endpoint("/v1", "/chat/completions"),
|
||||
}
|
||||
}
|
||||
|
|
@ -186,6 +231,20 @@ impl SupportedAPIsFromClient {
|
|||
match provider_id {
|
||||
// Providers that support /v1/responses natively
|
||||
ProviderId::OpenAI | ProviderId::XAI => route_by_provider("/responses"),
|
||||
ProviderId::Gemini => {
|
||||
// Translate Responses → Gemini native
|
||||
if !is_streaming {
|
||||
build_endpoint(
|
||||
"/v1beta",
|
||||
&format!("/models/{}:generateContent", model_id),
|
||||
)
|
||||
} else {
|
||||
build_endpoint(
|
||||
"/v1beta",
|
||||
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
|
||||
)
|
||||
}
|
||||
}
|
||||
// All other providers: translate to /chat/completions
|
||||
_ => route_by_provider("/chat/completions"),
|
||||
}
|
||||
|
|
@ -194,6 +253,33 @@ impl SupportedAPIsFromClient {
|
|||
// For Chat Completions API, use the standard chat/completions path
|
||||
route_by_provider("/chat/completions")
|
||||
}
|
||||
SupportedAPIsFromClient::GeminiGenerateContentAPI(_) => {
|
||||
match provider_id {
|
||||
ProviderId::Gemini => {
|
||||
// Native Gemini endpoint
|
||||
if !is_streaming {
|
||||
build_endpoint(
|
||||
"/v1beta",
|
||||
&format!("/models/{}:generateContent", model_id),
|
||||
)
|
||||
} else {
|
||||
build_endpoint(
|
||||
"/v1beta",
|
||||
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
|
||||
)
|
||||
}
|
||||
}
|
||||
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
|
||||
ProviderId::AmazonBedrock => {
|
||||
if !is_streaming {
|
||||
build_endpoint("", &format!("/model/{}/converse", model_id))
|
||||
} else {
|
||||
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
|
||||
}
|
||||
}
|
||||
_ => build_endpoint("/v1", "/chat/completions"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -201,6 +287,18 @@ impl SupportedAPIsFromClient {
|
|||
impl SupportedUpstreamAPIs {
|
||||
/// Create a SupportedUpstreamApi from an endpoint path
|
||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
// Check Gemini first since it uses suffix matching
|
||||
if let Some(gemini_api) = GeminiApi::from_endpoint(endpoint) {
|
||||
return match gemini_api {
|
||||
GeminiApi::GenerateContent => {
|
||||
Some(SupportedUpstreamAPIs::GeminiGenerateContent(gemini_api))
|
||||
}
|
||||
GeminiApi::StreamGenerateContent => Some(
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(gemini_api),
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||
// Check if this is the Responses API endpoint
|
||||
if openai_api == OpenAIApi::Responses {
|
||||
|
|
@ -396,7 +494,7 @@ mod tests {
|
|||
"/openai/deployments/gpt-4/chat/completions?api-version=2025-01-01-preview"
|
||||
);
|
||||
|
||||
// Test Gemini provider
|
||||
// Test Gemini provider (uses native Gemini API with transforms)
|
||||
assert_eq!(
|
||||
api.target_endpoint_for_provider(
|
||||
&ProviderId::Gemini,
|
||||
|
|
@ -405,7 +503,7 @@ mod tests {
|
|||
false,
|
||||
None
|
||||
),
|
||||
"/v1beta/openai/chat/completions"
|
||||
"/v1beta/models/gemini-pro:generateContent"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ pub use providers::streaming_response::{ProviderStreamResponse, ProviderStreamRe
|
|||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
pub const GENERATE_CONTENT_PATH_SUFFIX: &str = ":generateContent";
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
|
||||
use crate::apis::{AmazonBedrockApi, AnthropicApi, GeminiApi, OpenAIApi};
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -116,7 +116,68 @@ impl ProviderId {
|
|||
is_streaming: bool,
|
||||
) -> SupportedUpstreamAPIs {
|
||||
match (self, client_api) {
|
||||
// ============================================================================
|
||||
// Gemini provider — use native Gemini APIs
|
||||
// ============================================================================
|
||||
(ProviderId::Gemini, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
|
||||
GeminiApi::StreamGenerateContent,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
|
||||
}
|
||||
}
|
||||
(ProviderId::Gemini, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
|
||||
GeminiApi::StreamGenerateContent,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
|
||||
}
|
||||
}
|
||||
(ProviderId::Gemini, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
|
||||
GeminiApi::StreamGenerateContent,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
|
||||
}
|
||||
}
|
||||
(ProviderId::Gemini, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
|
||||
GeminiApi::StreamGenerateContent,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Non-Gemini providers receiving Gemini-format requests
|
||||
// ============================================================================
|
||||
(ProviderId::Anthropic, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
}
|
||||
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
(_, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Claude/Anthropic providers natively support Anthropic APIs
|
||||
// ============================================================================
|
||||
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
}
|
||||
|
|
@ -136,7 +197,6 @@ impl ProviderId {
|
|||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
|
|
@ -154,7 +214,6 @@ impl ProviderId {
|
|||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::apis::gemini::GenerateContentRequest;
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::ApiDefinition;
|
||||
|
||||
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
|
||||
use crate::apis::openai_responses::ResponsesAPIRequest;
|
||||
|
|
@ -19,7 +21,8 @@ pub enum ProviderRequestType {
|
|||
BedrockConverse(ConverseRequest),
|
||||
BedrockConverseStream(ConverseStreamRequest),
|
||||
ResponsesAPIRequest(ResponsesAPIRequest),
|
||||
//add more request types here
|
||||
GeminiGenerateContent(GenerateContentRequest),
|
||||
GeminiStreamGenerateContent(GenerateContentRequest),
|
||||
}
|
||||
pub trait ProviderRequest: Send + Sync {
|
||||
/// Extract the model name from the request
|
||||
|
|
@ -69,6 +72,9 @@ impl ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.set_messages(messages),
|
||||
Self::BedrockConverseStream(r) => r.set_messages(messages),
|
||||
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||
r.set_messages(messages)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,6 +106,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.model(),
|
||||
Self::BedrockConverseStream(r) => r.model(),
|
||||
Self::ResponsesAPIRequest(r) => r.model(),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.model(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -110,6 +117,9 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.set_model(model),
|
||||
Self::BedrockConverseStream(r) => r.set_model(model),
|
||||
Self::ResponsesAPIRequest(r) => r.set_model(model),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||
r.set_model(model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -120,6 +130,8 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(_) => false,
|
||||
Self::BedrockConverseStream(_) => true,
|
||||
Self::ResponsesAPIRequest(r) => r.is_streaming(),
|
||||
Self::GeminiGenerateContent(_) => false,
|
||||
Self::GeminiStreamGenerateContent(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -130,6 +142,9 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.extract_messages_text(),
|
||||
Self::BedrockConverseStream(r) => r.extract_messages_text(),
|
||||
Self::ResponsesAPIRequest(r) => r.extract_messages_text(),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||
r.extract_messages_text()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -140,6 +155,9 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.get_recent_user_message(),
|
||||
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
|
||||
Self::ResponsesAPIRequest(r) => r.get_recent_user_message(),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||
r.get_recent_user_message()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -150,6 +168,9 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.get_tool_names(),
|
||||
Self::BedrockConverseStream(r) => r.get_tool_names(),
|
||||
Self::ResponsesAPIRequest(r) => r.get_tool_names(),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||
r.get_tool_names()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -160,6 +181,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.to_bytes(),
|
||||
Self::BedrockConverseStream(r) => r.to_bytes(),
|
||||
Self::ResponsesAPIRequest(r) => r.to_bytes(),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.to_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -170,6 +192,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.metadata(),
|
||||
Self::BedrockConverseStream(r) => r.metadata(),
|
||||
Self::ResponsesAPIRequest(r) => r.metadata(),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.metadata(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -180,6 +203,9 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.remove_metadata_key(key),
|
||||
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
|
||||
Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||
r.remove_metadata_key(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -190,6 +216,9 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.get_temperature(),
|
||||
Self::BedrockConverseStream(r) => r.get_temperature(),
|
||||
Self::ResponsesAPIRequest(r) => r.get_temperature(),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||
r.get_temperature()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -200,6 +229,9 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.get_messages(),
|
||||
Self::BedrockConverseStream(r) => r.get_messages(),
|
||||
Self::ResponsesAPIRequest(r) => r.get_messages(),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||
r.get_messages()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -210,6 +242,9 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::BedrockConverse(r) => r.set_messages(messages),
|
||||
Self::BedrockConverseStream(r) => r.set_messages(messages),
|
||||
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
|
||||
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||
r.set_messages(messages)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -245,6 +280,18 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType {
|
|||
responses_apirequest,
|
||||
))
|
||||
}
|
||||
SupportedAPIsFromClient::GeminiGenerateContentAPI(gemini_api) => {
|
||||
let gemini_request: GenerateContentRequest =
|
||||
GenerateContentRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
if gemini_api.supports_streaming() {
|
||||
Ok(ProviderRequestType::GeminiStreamGenerateContent(
|
||||
gemini_request,
|
||||
))
|
||||
} else {
|
||||
Ok(ProviderRequestType::GeminiGenerateContent(gemini_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -309,6 +356,37 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
source: None,
|
||||
})
|
||||
}
|
||||
// ChatCompletions -> Gemini
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||
) => {
|
||||
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to GenerateContentRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::GeminiGenerateContent(gemini_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||
) => {
|
||||
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to GenerateContentRequest (stream): {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MessagesRequest conversions
|
||||
|
|
@ -370,6 +448,37 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
source: None,
|
||||
})
|
||||
}
|
||||
// Messages -> Gemini (chain: Anthropic -> OpenAI -> Gemini)
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||
) => {
|
||||
let gemini_req = GenerateContentRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to GenerateContentRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::GeminiGenerateContent(gemini_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||
) => {
|
||||
let gemini_req = GenerateContentRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to GenerateContentRequest (stream): {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ResponsesAPIRequest conversions (only converts TO other formats)
|
||||
|
|
@ -480,6 +589,171 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||
}
|
||||
|
||||
// ResponsesAPI -> Gemini (via ChatCompletions)
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||
) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to GenerateContentRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::GeminiGenerateContent(gemini_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||
) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to GenerateContentRequest (stream): {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// GeminiGenerateContent conversions (client sends Gemini format)
|
||||
// ============================================================================
|
||||
(
|
||||
ProviderRequestType::GeminiGenerateContent(gemini_req),
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||
) => Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)),
|
||||
(
|
||||
ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||
) => Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)),
|
||||
// Cross-streaming mode: non-streaming -> streaming and vice versa
|
||||
(
|
||||
ProviderRequestType::GeminiGenerateContent(gemini_req),
|
||||
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||
) => Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)),
|
||||
(
|
||||
ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||
) => Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)),
|
||||
(
|
||||
ProviderRequestType::GeminiGenerateContent(gemini_req)
|
||||
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::GeminiGenerateContent(gemini_req)
|
||||
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let messages_req = MessagesRequest::try_from(gemini_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert GenerateContentRequest to MessagesRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::GeminiGenerateContent(gemini_req)
|
||||
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
// Chain: Gemini -> OpenAI -> Bedrock
|
||||
let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to ConverseRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::GeminiGenerateContent(gemini_req)
|
||||
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
// Chain: Gemini -> OpenAI -> Bedrock Stream
|
||||
let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to ConverseStreamRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::GeminiGenerateContent(_)
|
||||
| ProviderRequestType::GeminiStreamGenerateContent(_),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
Err(ProviderRequestError {
|
||||
message: "Conversion from GenerateContentRequest to ResponsesAPIRequest is not supported.".to_string(),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Amazon Bedrock conversions (not supported as client API)
|
||||
// ============================================================================
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use crate::apis::amazon_bedrock::ConverseResponse;
|
||||
use crate::apis::anthropic::MessagesResponse;
|
||||
use crate::apis::gemini::GenerateContentResponse;
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::apis::openai_responses::ResponsesAPIResponse;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
|
|
@ -16,6 +17,7 @@ pub enum ProviderResponseType {
|
|||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
MessagesResponse(MessagesResponse),
|
||||
ResponsesAPIResponse(Box<ResponsesAPIResponse>),
|
||||
GenerateContentResponse(GenerateContentResponse),
|
||||
}
|
||||
|
||||
/// Trait for token usage information
|
||||
|
|
@ -44,6 +46,9 @@ impl ProviderResponse for ProviderResponseType {
|
|||
ProviderResponseType::ResponsesAPIResponse(resp) => {
|
||||
resp.usage.as_ref().map(|u| u as &dyn TokenUsage)
|
||||
}
|
||||
ProviderResponseType::GenerateContentResponse(resp) => {
|
||||
resp.usage_metadata.as_ref().map(|u| u as &dyn TokenUsage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -58,6 +63,15 @@ impl ProviderResponse for ProviderResponseType {
|
|||
u.total_tokens as usize,
|
||||
)
|
||||
}),
|
||||
ProviderResponseType::GenerateContentResponse(resp) => {
|
||||
resp.usage_metadata.as_ref().map(|u| {
|
||||
(
|
||||
u.prompt_token_count.unwrap_or(0) as usize,
|
||||
u.candidates_token_count.unwrap_or(0) as usize,
|
||||
u.total_token_count.unwrap_or(0) as usize,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -238,6 +252,140 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
|
|||
response_api,
|
||||
)))
|
||||
}
|
||||
// ============================================================================
|
||||
// Gemini upstream transformations
|
||||
// ============================================================================
|
||||
(
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
|
||||
) => {
|
||||
// Passthrough: Gemini upstream -> Gemini client
|
||||
let resp: GenerateContentResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::GenerateContentResponse(resp))
|
||||
}
|
||||
(
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
// Gemini upstream -> OpenAI client
|
||||
let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
|
||||
}
|
||||
(
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
// Chain: Gemini -> OpenAI -> Anthropic
|
||||
let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
let messages_resp: MessagesResponse = chat_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::MessagesResponse(messages_resp))
|
||||
}
|
||||
(
|
||||
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
// Chain: Gemini -> OpenAI -> ResponsesAPI
|
||||
let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
let responses_resp: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
|
||||
responses_resp,
|
||||
)))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Non-Gemini upstream -> Gemini client
|
||||
// ============================================================================
|
||||
(
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
|
||||
) => {
|
||||
// OpenAI upstream -> Gemini client
|
||||
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
let gemini_resp: GenerateContentResponse = openai_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
|
||||
}
|
||||
(
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
|
||||
) => {
|
||||
// Chain: Anthropic -> OpenAI -> Gemini
|
||||
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
let chat_resp: ChatCompletionsResponse =
|
||||
anthropic_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
let gemini_resp: GenerateContentResponse = chat_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
|
||||
}
|
||||
(
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
|
||||
) => {
|
||||
// Chain: Bedrock -> OpenAI -> Gemini
|
||||
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
let gemini_resp: GenerateContentResponse = chat_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
|
||||
}
|
||||
|
||||
_ => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Unsupported API combination for response transformation",
|
||||
|
|
|
|||
|
|
@ -83,6 +83,11 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseStreamBu
|
|||
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
|
||||
Ok(SseStreamBuffer::OpenAIResponses(Box::default()))
|
||||
}
|
||||
SupportedAPIsFromClient::GeminiGenerateContentAPI(_) => {
|
||||
// Gemini client with a different upstream - use passthrough
|
||||
// since Gemini streaming uses SSE and doesn't need special buffering
|
||||
Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ pub mod response_streaming;
|
|||
|
||||
// Re-export commonly used items for convenience
|
||||
pub use lib::*;
|
||||
#[allow(ambiguous_glob_reexports)]
|
||||
pub use request::*;
|
||||
pub use response::*;
|
||||
pub use response_streaming::*;
|
||||
|
|
|
|||
327
crates/hermesllm/src/transforms/request/from_gemini.rs
Normal file
327
crates/hermesllm/src/transforms/request/from_gemini.rs
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
use crate::apis::gemini::GenerateContentRequest;
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsRequest, Function, FunctionCall as OpenAIFunctionCall, Message, MessageContent,
|
||||
Role, Tool, ToolCall as OpenAIToolCall, ToolChoice, ToolChoiceType,
|
||||
};
|
||||
|
||||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::clients::TransformError;
|
||||
|
||||
// ============================================================================
|
||||
// Gemini GenerateContent -> OpenAI ChatCompletions
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<GenerateContentRequest> for ChatCompletionsRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: GenerateContentRequest) -> Result<Self, Self::Error> {
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
|
||||
// Convert system instruction
|
||||
if let Some(system) = &req.system_instruction {
|
||||
let text = system
|
||||
.parts
|
||||
.iter()
|
||||
.filter_map(|p| p.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
if !text.is_empty() {
|
||||
messages.push(Message {
|
||||
role: Role::System,
|
||||
content: Some(MessageContent::Text(text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert contents
|
||||
for content in &req.contents {
|
||||
let role = match content.role.as_deref() {
|
||||
Some("model") => Role::Assistant,
|
||||
_ => Role::User,
|
||||
};
|
||||
|
||||
// Check if this content has function_call parts (assistant with tool calls)
|
||||
let has_function_calls = content.parts.iter().any(|p| p.function_call.is_some());
|
||||
let has_function_responses =
|
||||
content.parts.iter().any(|p| p.function_response.is_some());
|
||||
|
||||
if has_function_calls {
|
||||
// Convert to assistant message with tool_calls
|
||||
let mut tool_calls = Vec::new();
|
||||
let mut text_parts = Vec::new();
|
||||
|
||||
for (i, part) in content.parts.iter().enumerate() {
|
||||
if let Some(fc) = &part.function_call {
|
||||
tool_calls.push(OpenAIToolCall {
|
||||
id: format!("call_{}", i),
|
||||
call_type: "function".to_string(),
|
||||
function: OpenAIFunctionCall {
|
||||
name: fc.name.clone(),
|
||||
arguments: serde_json::to_string(&fc.args).unwrap_or_default(),
|
||||
},
|
||||
});
|
||||
} else if let Some(text) = &part.text {
|
||||
text_parts.push(text.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let content_text = if text_parts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(MessageContent::Text(text_parts.join("")))
|
||||
};
|
||||
|
||||
messages.push(Message {
|
||||
role: Role::Assistant,
|
||||
content: content_text,
|
||||
name: None,
|
||||
tool_calls: if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
},
|
||||
tool_call_id: None,
|
||||
});
|
||||
} else if has_function_responses {
|
||||
// Convert each function_response to a tool message
|
||||
for part in &content.parts {
|
||||
if let Some(fr) = &part.function_response {
|
||||
let result_text = serde_json::to_string(&fr.response).unwrap_or_default();
|
||||
messages.push(Message {
|
||||
role: Role::Tool,
|
||||
content: Some(MessageContent::Text(result_text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(fr.name.clone()),
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Regular text message
|
||||
let text = content
|
||||
.parts
|
||||
.iter()
|
||||
.filter_map(|p| p.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
messages.push(Message {
|
||||
role,
|
||||
content: Some(MessageContent::Text(text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert generation config
|
||||
let (temperature, top_p, max_tokens, stop, presence_penalty, frequency_penalty) =
|
||||
if let Some(gc) = &req.generation_config {
|
||||
(
|
||||
gc.temperature,
|
||||
gc.top_p,
|
||||
gc.max_output_tokens,
|
||||
gc.stop_sequences.clone(),
|
||||
gc.presence_penalty,
|
||||
gc.frequency_penalty,
|
||||
)
|
||||
} else {
|
||||
(None, None, None, None, None, None)
|
||||
};
|
||||
|
||||
// Convert tools
|
||||
let tools = req.tools.and_then(|gemini_tools| {
|
||||
let openai_tools: Vec<Tool> = gemini_tools
|
||||
.iter()
|
||||
.filter_map(|t| t.function_declarations.as_ref())
|
||||
.flatten()
|
||||
.map(|fd| Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: fd.name.clone(),
|
||||
description: fd.description.clone(),
|
||||
parameters: fd.parameters.clone().unwrap_or_default(),
|
||||
strict: None,
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
if openai_tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(openai_tools)
|
||||
}
|
||||
});
|
||||
|
||||
// Convert tool_config
|
||||
let tool_choice =
|
||||
req.tool_config
|
||||
.and_then(|tc| match tc.function_calling_config.mode.as_str() {
|
||||
"AUTO" => Some(ToolChoice::Type(ToolChoiceType::Auto)),
|
||||
"NONE" => Some(ToolChoice::Type(ToolChoiceType::None)),
|
||||
"ANY" => Some(ToolChoice::Type(ToolChoiceType::Required)),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
Ok(ChatCompletionsRequest {
|
||||
model: req.model,
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
max_completion_tokens: max_tokens,
|
||||
stop,
|
||||
tools,
|
||||
tool_choice,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
metadata: req.metadata,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Gemini GenerateContent -> Anthropic Messages (via OpenAI)
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<GenerateContentRequest> for MessagesRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: GenerateContentRequest) -> Result<Self, Self::Error> {
|
||||
// Chain: Gemini -> OpenAI -> Anthropic
|
||||
let chat_req = ChatCompletionsRequest::try_from(req)?;
|
||||
MessagesRequest::try_from(chat_req)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::apis::gemini::{Content, FunctionCall, Part};
    use serde_json::json;

    // Basic Gemini -> OpenAI conversion: `system_instruction` becomes a
    // leading System message, user/model contents map to User/Assistant
    // messages, and generation_config fields carry over.
    #[test]
    fn test_gemini_to_openai_basic() {
        let req = GenerateContentRequest {
            model: "gemini-pro".to_string(),
            contents: vec![
                Content {
                    role: Some("user".to_string()),
                    parts: vec![Part {
                        text: Some("Hello".to_string()),
                        inline_data: None,
                        function_call: None,
                        function_response: None,
                    }],
                },
                Content {
                    role: Some("model".to_string()),
                    parts: vec![Part {
                        text: Some("Hi there!".to_string()),
                        inline_data: None,
                        function_call: None,
                        function_response: None,
                    }],
                },
            ],
            system_instruction: Some(Content {
                role: Some("user".to_string()),
                parts: vec![Part {
                    text: Some("Be helpful".to_string()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }],
            }),
            generation_config: Some(crate::apis::gemini::GenerationConfig {
                temperature: Some(0.5),
                max_output_tokens: Some(512),
                ..Default::default()
            }),
            ..Default::default()
        };

        let openai_req = ChatCompletionsRequest::try_from(req).unwrap();

        // System + user + assistant = 3 messages
        assert_eq!(openai_req.messages.len(), 3);
        assert_eq!(openai_req.messages[0].role, Role::System);
        assert_eq!(openai_req.messages[1].role, Role::User);
        assert_eq!(openai_req.messages[2].role, Role::Assistant);

        assert_eq!(openai_req.temperature, Some(0.5));
        assert_eq!(openai_req.max_completion_tokens, Some(512));
    }

    // A Gemini `function_call` part on a model turn should surface as an
    // OpenAI assistant `tool_call` with the same function name.
    #[test]
    fn test_gemini_to_openai_with_function_calls() {
        let req = GenerateContentRequest {
            model: "gemini-pro".to_string(),
            contents: vec![
                Content {
                    role: Some("user".to_string()),
                    parts: vec![Part {
                        text: Some("Weather?".to_string()),
                        inline_data: None,
                        function_call: None,
                        function_response: None,
                    }],
                },
                Content {
                    role: Some("model".to_string()),
                    parts: vec![Part {
                        text: None,
                        inline_data: None,
                        function_call: Some(FunctionCall {
                            name: "get_weather".to_string(),
                            args: json!({"location": "NYC"}),
                        }),
                        function_response: None,
                    }],
                },
            ],
            ..Default::default()
        };

        let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
        assert_eq!(openai_req.messages.len(), 2);
        assert!(openai_req.messages[1].tool_calls.is_some());
        let tc = openai_req.messages[1].tool_calls.as_ref().unwrap();
        assert_eq!(tc[0].function.name, "get_weather");
    }

    // Gemini tool_config mode "ANY" (must call a tool) maps to OpenAI
    // tool_choice "required".
    #[test]
    fn test_gemini_to_openai_tool_config() {
        let req = GenerateContentRequest {
            model: "gemini-pro".to_string(),
            contents: vec![Content {
                role: Some("user".to_string()),
                parts: vec![Part {
                    text: Some("test".to_string()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }],
            }],
            tool_config: Some(crate::apis::gemini::ToolConfig {
                function_calling_config: crate::apis::gemini::FunctionCallingConfig {
                    mode: "ANY".to_string(),
                },
            }),
            ..Default::default()
        };

        let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
        assert!(openai_req.tool_choice.is_some());
        assert_eq!(
            openai_req.tool_choice.as_ref().unwrap(),
            &ToolChoice::Type(ToolChoiceType::Required)
        );
    }
}
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
//! Request transformation modules
//!
//! Submodules provide `TryFrom` conversions between provider request
//! formats (e.g. Gemini GenerateContent <-> OpenAI ChatCompletions).

pub mod from_anthropic;
pub mod from_gemini;
pub mod from_openai;
pub mod to_gemini;
|
||||
|
|
|
|||
323
crates/hermesllm/src/transforms/request/to_gemini.rs
Normal file
323
crates/hermesllm/src/transforms/request/to_gemini.rs
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
use crate::apis::gemini::{
|
||||
Content, FunctionCall, FunctionCallingConfig, FunctionDeclaration, FunctionResponse,
|
||||
GenerateContentRequest, GenerationConfig, Part, Tool, ToolConfig,
|
||||
};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Role, ToolChoice, ToolChoiceType};
|
||||
|
||||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
|
||||
// ============================================================================
|
||||
// OpenAI ChatCompletions -> Gemini GenerateContent
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<ChatCompletionsRequest> for GenerateContentRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
|
||||
let mut contents: Vec<Content> = Vec::new();
|
||||
let mut system_instruction: Option<Content> = None;
|
||||
|
||||
for msg in &req.messages {
|
||||
match msg.role {
|
||||
Role::System => {
|
||||
let text = msg.content.extract_text();
|
||||
system_instruction = Some(Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
}],
|
||||
});
|
||||
}
|
||||
Role::User => {
|
||||
let text = msg.content.extract_text();
|
||||
contents.push(Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
}],
|
||||
});
|
||||
}
|
||||
Role::Assistant => {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
// Check for tool calls
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
for tc in tool_calls {
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&tc.function.arguments).unwrap_or_default();
|
||||
parts.push(Part {
|
||||
text: None,
|
||||
inline_data: None,
|
||||
function_call: Some(FunctionCall {
|
||||
name: tc.function.name.clone(),
|
||||
args,
|
||||
}),
|
||||
function_response: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Also include text content if present
|
||||
let text = msg.content.extract_text();
|
||||
if !text.is_empty() {
|
||||
parts.push(Part {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
});
|
||||
}
|
||||
|
||||
if !parts.is_empty() {
|
||||
contents.push(Content {
|
||||
role: Some("model".to_string()),
|
||||
parts,
|
||||
});
|
||||
}
|
||||
}
|
||||
Role::Tool => {
|
||||
let text = msg.content.extract_text();
|
||||
let tool_call_id = msg.tool_call_id.clone().unwrap_or_default();
|
||||
let response_value = serde_json::from_str(&text)
|
||||
.unwrap_or_else(|_| serde_json::json!({"result": text}));
|
||||
|
||||
contents.push(Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: None,
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: Some(FunctionResponse {
|
||||
name: tool_call_id,
|
||||
response: response_value,
|
||||
}),
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert generation config
|
||||
let generation_config = {
|
||||
let gc = GenerationConfig {
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
top_k: None,
|
||||
max_output_tokens: req.max_completion_tokens.or(req.max_tokens),
|
||||
stop_sequences: req.stop,
|
||||
response_mime_type: None,
|
||||
candidate_count: None,
|
||||
presence_penalty: req.presence_penalty,
|
||||
frequency_penalty: req.frequency_penalty,
|
||||
};
|
||||
// Only include if any field is set
|
||||
if gc.temperature.is_some()
|
||||
|| gc.top_p.is_some()
|
||||
|| gc.max_output_tokens.is_some()
|
||||
|| gc.stop_sequences.is_some()
|
||||
|| gc.presence_penalty.is_some()
|
||||
|| gc.frequency_penalty.is_some()
|
||||
{
|
||||
Some(gc)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Convert tools
|
||||
let tools = req.tools.map(|openai_tools| {
|
||||
let declarations: Vec<FunctionDeclaration> = openai_tools
|
||||
.iter()
|
||||
.map(|t| FunctionDeclaration {
|
||||
name: t.function.name.clone(),
|
||||
description: t.function.description.clone(),
|
||||
parameters: Some(t.function.parameters.clone()),
|
||||
})
|
||||
.collect();
|
||||
vec![Tool {
|
||||
function_declarations: Some(declarations),
|
||||
code_execution: None,
|
||||
}]
|
||||
});
|
||||
|
||||
// Convert tool_choice
|
||||
let tool_config = req.tool_choice.and_then(|tc| {
|
||||
let mode = match tc {
|
||||
ToolChoice::Type(t) => match t {
|
||||
ToolChoiceType::Auto => Some("AUTO".to_string()),
|
||||
ToolChoiceType::None => Some("NONE".to_string()),
|
||||
ToolChoiceType::Required => Some("ANY".to_string()),
|
||||
},
|
||||
ToolChoice::Function { .. } => Some("AUTO".to_string()),
|
||||
};
|
||||
mode.map(|m| ToolConfig {
|
||||
function_calling_config: FunctionCallingConfig { mode: m },
|
||||
})
|
||||
});
|
||||
|
||||
Ok(GenerateContentRequest {
|
||||
model: req.model,
|
||||
contents,
|
||||
generation_config,
|
||||
tools,
|
||||
tool_config,
|
||||
safety_settings: None,
|
||||
system_instruction,
|
||||
cached_content: None,
|
||||
metadata: req.metadata,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Anthropic Messages -> Gemini GenerateContent (via OpenAI)
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<MessagesRequest> for GenerateContentRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: MessagesRequest) -> Result<Self, Self::Error> {
|
||||
// Chain: Anthropic -> OpenAI -> Gemini
|
||||
let chat_req = ChatCompletionsRequest::try_from(req)?;
|
||||
GenerateContentRequest::try_from(chat_req)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Basic OpenAI -> Gemini conversion: the system message lands in
    // `system_instruction`, the remaining turns become `contents`, and
    // sampling parameters move into `generation_config`.
    #[test]
    fn test_openai_to_gemini_basic() {
        let req: ChatCompletionsRequest = serde_json::from_value(json!({
            "model": "gemini-pro",
            "messages": [
                {"role": "system", "content": "You are helpful"},
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
                {"role": "user", "content": "How are you?"}
            ],
            "temperature": 0.7,
            "max_tokens": 1024
        }))
        .unwrap();

        let gemini_req = GenerateContentRequest::try_from(req).unwrap();

        // System should be in system_instruction
        assert!(gemini_req.system_instruction.is_some());
        let sys = gemini_req.system_instruction.as_ref().unwrap();
        assert_eq!(sys.parts[0].text.as_deref(), Some("You are helpful"));

        // 3 content messages (user, model, user)
        assert_eq!(gemini_req.contents.len(), 3);
        assert_eq!(gemini_req.contents[0].role.as_deref(), Some("user"));
        assert_eq!(gemini_req.contents[1].role.as_deref(), Some("model"));
        assert_eq!(gemini_req.contents[2].role.as_deref(), Some("user"));

        // Generation config
        assert_eq!(
            gemini_req.generation_config.as_ref().unwrap().temperature,
            Some(0.7)
        );
        assert_eq!(
            gemini_req
                .generation_config
                .as_ref()
                .unwrap()
                .max_output_tokens,
            Some(1024)
        );
    }

    // OpenAI tool definitions become Gemini function declarations, and
    // tool_choice "auto" maps to function-calling mode "AUTO".
    #[test]
    fn test_openai_to_gemini_with_tools() {
        let req: ChatCompletionsRequest = serde_json::from_value(json!({
            "model": "gemini-pro",
            "messages": [
                {"role": "user", "content": "What's the weather?"}
            ],
            "tools": [{
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get weather",
                    "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}
                }
            }],
            "tool_choice": "auto"
        }))
        .unwrap();

        let gemini_req = GenerateContentRequest::try_from(req).unwrap();
        assert!(gemini_req.tools.is_some());
        let tools = gemini_req.tools.as_ref().unwrap();
        assert_eq!(tools.len(), 1);
        let decls = tools[0].function_declarations.as_ref().unwrap();
        assert_eq!(decls[0].name, "get_weather");

        assert!(gemini_req.tool_config.is_some());
        assert_eq!(
            gemini_req
                .tool_config
                .as_ref()
                .unwrap()
                .function_calling_config
                .mode,
            "AUTO"
        );
    }

    // A full tool round-trip: assistant tool_calls become model
    // `function_call` parts, and the tool result becomes a user content
    // with a `function_response` part.
    #[test]
    fn test_openai_to_gemini_with_tool_calls() {
        let req: ChatCompletionsRequest = serde_json::from_value(json!({
            "model": "gemini-pro",
            "messages": [
                {"role": "user", "content": "What's the weather?"},
                {
                    "role": "assistant",
                    "tool_calls": [{
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"location\": \"NYC\"}"
                        }
                    }]
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_123",
                    "content": "Sunny, 72F"
                }
            ]
        }))
        .unwrap();

        let gemini_req = GenerateContentRequest::try_from(req).unwrap();
        assert_eq!(gemini_req.contents.len(), 3);

        // Assistant with function_call
        let model_content = &gemini_req.contents[1];
        assert_eq!(model_content.role.as_deref(), Some("model"));
        assert!(model_content.parts[0].function_call.is_some());

        // Tool response
        let tool_content = &gemini_req.contents[2];
        assert_eq!(tool_content.role.as_deref(), Some("user"));
        assert!(tool_content.parts[0].function_response.is_some());
    }
}
|
||||
417
crates/hermesllm/src/transforms/response/from_gemini.rs
Normal file
417
crates/hermesllm/src/transforms/response/from_gemini.rs
Normal file
|
|
@ -0,0 +1,417 @@
|
|||
use crate::apis::anthropic::MessagesResponse;
|
||||
use crate::apis::gemini::GenerateContentResponse;
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason,
|
||||
FunctionCall as OpenAIFunctionCall, MessageDelta, ResponseMessage, Role, StreamChoice,
|
||||
ToolCall as OpenAIToolCall, Usage,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
|
||||
// ============================================================================
|
||||
// Gemini GenerateContentResponse -> OpenAI ChatCompletionsResponse
|
||||
// ============================================================================
|
||||
|
||||
fn map_finish_reason(gemini_reason: Option<&str>) -> Option<FinishReason> {
|
||||
gemini_reason.map(|r| match r {
|
||||
"STOP" => FinishReason::Stop,
|
||||
"MAX_TOKENS" => FinishReason::Length,
|
||||
"SAFETY" | "RECITATION" => FinishReason::ContentFilter,
|
||||
_ => FinishReason::Stop,
|
||||
})
|
||||
}
|
||||
|
||||
impl TryFrom<GenerateContentResponse> for ChatCompletionsResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
|
||||
let candidates = resp.candidates.unwrap_or_default();
|
||||
let candidate = candidates.first();
|
||||
|
||||
let mut content_text = String::new();
|
||||
let mut tool_calls: Vec<OpenAIToolCall> = Vec::new();
|
||||
|
||||
if let Some(candidate) = candidate {
|
||||
if let Some(ref content) = candidate.content {
|
||||
for (i, part) in content.parts.iter().enumerate() {
|
||||
if let Some(ref text) = part.text {
|
||||
content_text.push_str(text);
|
||||
}
|
||||
if let Some(ref fc) = part.function_call {
|
||||
tool_calls.push(OpenAIToolCall {
|
||||
id: format!("call_{}", i),
|
||||
call_type: "function".to_string(),
|
||||
function: OpenAIFunctionCall {
|
||||
name: fc.name.clone(),
|
||||
arguments: serde_json::to_string(&fc.args).unwrap_or_default(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let finish_reason = candidate
|
||||
.and_then(|c| map_finish_reason(c.finish_reason.as_deref()))
|
||||
.unwrap_or(FinishReason::Stop);
|
||||
|
||||
let message_content = if content_text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content_text)
|
||||
};
|
||||
|
||||
let tool_calls_opt = if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
};
|
||||
|
||||
let choice = Choice {
|
||||
index: 0,
|
||||
message: ResponseMessage {
|
||||
role: Role::Assistant,
|
||||
content: message_content,
|
||||
tool_calls: tool_calls_opt,
|
||||
refusal: None,
|
||||
annotations: None,
|
||||
audio: None,
|
||||
function_call: None,
|
||||
},
|
||||
finish_reason: Some(finish_reason),
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
let usage = resp
|
||||
.usage_metadata
|
||||
.map(|um| Usage {
|
||||
prompt_tokens: um.prompt_token_count.unwrap_or(0),
|
||||
completion_tokens: um.candidates_token_count.unwrap_or(0),
|
||||
total_tokens: um.total_token_count.unwrap_or(0),
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(ChatCompletionsResponse {
|
||||
id: format!(
|
||||
"chatcmpl-gemini-{}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis()
|
||||
),
|
||||
object: Some("chat.completion".to_string()),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
model: resp.model_version.unwrap_or_else(|| "gemini".to_string()),
|
||||
choices: vec![choice],
|
||||
usage,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
metadata: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Gemini GenerateContentResponse -> Anthropic MessagesResponse (via OpenAI)
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<GenerateContentResponse> for MessagesResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
|
||||
// Chain: Gemini -> OpenAI -> Anthropic
|
||||
let chat_resp = ChatCompletionsResponse::try_from(resp)?;
|
||||
MessagesResponse::try_from(chat_resp)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Gemini GenerateContentResponse -> OpenAI ChatCompletionsStreamResponse
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<GenerateContentResponse> for ChatCompletionsStreamResponse {
    type Error = TransformError;

    /// Converts a single Gemini streaming chunk (a GenerateContentResponse)
    /// into an OpenAI `chat.completion.chunk`.
    ///
    /// Only the first candidate's text parts are forwarded as the delta
    /// content; usage is always `None` on chunks.
    /// NOTE(review): `function_call` parts are ignored here, so streaming
    /// tool-call deltas are never emitted — confirm callers only need
    /// text streaming from Gemini.
    fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
        let candidates = resp.candidates.unwrap_or_default();
        let candidate = candidates.first();

        let mut delta_content: Option<String> = None;

        if let Some(candidate) = candidate {
            if let Some(ref content) = candidate.content {
                let mut text_parts = Vec::new();

                // Collect only text parts; other part kinds are dropped.
                for part in content.parts.iter() {
                    if let Some(ref text) = part.text {
                        text_parts.push(text.clone());
                    }
                }

                if !text_parts.is_empty() {
                    delta_content = Some(text_parts.join(""));
                }
            }
        }

        let finish_reason = candidate.and_then(|c| map_finish_reason(c.finish_reason.as_deref()));

        // Gemini's "model" role maps to the assistant role in deltas.
        let role = candidate
            .and_then(|c| c.content.as_ref())
            .and_then(|c| c.role.as_deref())
            .map(|r| match r {
                "model" => Role::Assistant,
                _ => Role::User,
            });

        Ok(ChatCompletionsStreamResponse {
            // Synthetic id; Gemini chunks carry no response id.
            id: format!(
                "chatcmpl-gemini-{}",
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_millis()
            ),
            object: Some("chat.completion.chunk".to_string()),
            created: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            model: resp.model_version.unwrap_or_else(|| "gemini".to_string()),
            choices: vec![StreamChoice {
                index: 0,
                delta: MessageDelta {
                    role,
                    content: delta_content,
                    tool_calls: None,
                    refusal: None,
                    function_call: None,
                },
                finish_reason,
                logprobs: None,
            }],
            usage: None,
            system_fingerprint: None,
            service_tier: None,
        })
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// REVERSE: OpenAI ChatCompletionsResponse -> Gemini GenerateContentResponse
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<ChatCompletionsResponse> for GenerateContentResponse {
    type Error = TransformError;

    /// Converts an OpenAI ChatCompletions response back into a Gemini
    /// GenerateContent response (the reverse direction, used when serving
    /// Gemini-format clients from an OpenAI-format upstream).
    ///
    /// Only the first choice is mapped. The original OpenAI tool-call ids
    /// are dropped — Gemini `function_call` parts carry only name + args.
    fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
        use crate::apis::gemini::{Candidate, Content, FunctionCall, Part, UsageMetadata};

        let candidates = if let Some(choice) = resp.choices.first() {
            let mut parts = Vec::new();

            // Text content
            if let Some(ref content) = choice.message.content {
                if !content.is_empty() {
                    parts.push(Part {
                        text: Some(content.clone()),
                        inline_data: None,
                        function_call: None,
                        function_response: None,
                    });
                }
            }

            // Tool calls
            if let Some(ref tool_calls) = choice.message.tool_calls {
                for tc in tool_calls {
                    // Arguments arrive as a JSON string; re-parse into a value.
                    let args: serde_json::Value =
                        serde_json::from_str(&tc.function.arguments).unwrap_or_default();
                    parts.push(Part {
                        text: None,
                        inline_data: None,
                        function_call: Some(FunctionCall {
                            name: tc.function.name.clone(),
                            args,
                        }),
                        function_response: None,
                    });
                }
            }

            // Gemini candidates must carry at least one part; pad with an
            // empty text part when the choice had no content at all.
            if parts.is_empty() {
                parts.push(Part {
                    text: Some(String::new()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                });
            }

            // Gemini has no dedicated tool-call finish reason; tool/function
            // call finishes collapse to STOP.
            let finish_reason = choice.finish_reason.as_ref().map(|fr| match fr {
                FinishReason::Stop => "STOP".to_string(),
                FinishReason::Length => "MAX_TOKENS".to_string(),
                FinishReason::ContentFilter => "SAFETY".to_string(),
                FinishReason::ToolCalls => "STOP".to_string(),
                FinishReason::FunctionCall => "STOP".to_string(),
            });

            vec![Candidate {
                content: Some(Content {
                    role: Some("model".to_string()),
                    parts,
                }),
                finish_reason,
                safety_ratings: None,
            }]
        } else {
            vec![]
        };

        let usage_metadata = Some(UsageMetadata {
            prompt_token_count: Some(resp.usage.prompt_tokens),
            candidates_token_count: Some(resp.usage.completion_tokens),
            total_token_count: Some(resp.usage.total_tokens),
        });

        Ok(GenerateContentResponse {
            candidates: Some(candidates),
            usage_metadata,
            prompt_feedback: None,
            model_version: Some(resp.model),
        })
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Gemini -> OpenAI response: text, finish reason, and usage counts
    // should all carry over.
    #[test]
    fn test_gemini_to_openai_response() {
        let resp: GenerateContentResponse = serde_json::from_value(json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{"text": "Hello! How can I help?"}]
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 7,
                "totalTokenCount": 12
            },
            "modelVersion": "gemini-2.0-flash"
        }))
        .unwrap();

        let openai_resp = ChatCompletionsResponse::try_from(resp).unwrap();
        assert_eq!(openai_resp.choices.len(), 1);
        let msg = &openai_resp.choices[0].message;
        assert_eq!(msg.content.as_deref(), Some("Hello! How can I help?"));
        assert_eq!(
            openai_resp.choices[0].finish_reason,
            Some(FinishReason::Stop)
        );
        assert_eq!(openai_resp.usage.prompt_tokens, 5);
        assert_eq!(openai_resp.usage.completion_tokens, 7);
    }

    // A Gemini chunk converts to an OpenAI stream chunk whose delta holds
    // the text and an assistant role.
    #[test]
    fn test_gemini_to_openai_stream_response() {
        let resp: GenerateContentResponse = serde_json::from_value(json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{"text": "Hello"}]
                }
            }]
        }))
        .unwrap();

        let stream_resp = ChatCompletionsStreamResponse::try_from(resp).unwrap();
        assert_eq!(stream_resp.choices.len(), 1);
        assert_eq!(
            stream_resp.choices[0].delta.content,
            Some("Hello".to_string())
        );
        assert_eq!(stream_resp.choices[0].delta.role, Some(Role::Assistant));
    }

    // A Gemini functionCall part surfaces as an OpenAI tool_call with the
    // same function name.
    #[test]
    fn test_gemini_to_openai_with_function_call() {
        let resp: GenerateContentResponse = serde_json::from_value(json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"location": "NYC"}
                        }
                    }]
                },
                "finishReason": "STOP"
            }]
        }))
        .unwrap();

        let openai_resp = ChatCompletionsResponse::try_from(resp).unwrap();
        let msg = &openai_resp.choices[0].message;
        assert!(msg.tool_calls.is_some());
        let tc = msg.tool_calls.as_ref().unwrap();
        assert_eq!(tc[0].function.name, "get_weather");
    }

    // Reverse direction: an OpenAI response becomes a Gemini response with
    // a single "model" candidate and the mapped finish reason.
    #[test]
    fn test_openai_to_gemini_response() {
        let resp: ChatCompletionsResponse = serde_json::from_value(json!({
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}
        }))
        .unwrap();

        let gemini_resp = GenerateContentResponse::try_from(resp).unwrap();
        let candidates = gemini_resp.candidates.as_ref().unwrap();
        assert_eq!(candidates.len(), 1);
        let parts = &candidates[0].content.as_ref().unwrap().parts;
        assert_eq!(parts[0].text.as_deref(), Some("Hello!"));
        assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP"));
    }

    // Exhaustive mapping table for the finish-reason helper, including the
    // None passthrough.
    #[test]
    fn test_finish_reason_mapping() {
        assert_eq!(map_finish_reason(Some("STOP")), Some(FinishReason::Stop));
        assert_eq!(
            map_finish_reason(Some("MAX_TOKENS")),
            Some(FinishReason::Length)
        );
        assert_eq!(
            map_finish_reason(Some("SAFETY")),
            Some(FinishReason::ContentFilter)
        );
        assert_eq!(
            map_finish_reason(Some("RECITATION")),
            Some(FinishReason::ContentFilter)
        );
        assert_eq!(map_finish_reason(None), None);
    }
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
//! Response transformation modules
//!
//! Submodules provide `TryFrom` conversions between provider response
//! formats (e.g. Gemini GenerateContent <-> OpenAI ChatCompletions).

pub mod from_gemini;
pub mod output_to_input;
pub mod to_anthropic;
pub mod to_openai;
|
||||
|
|
|
|||
|
|
@ -217,7 +217,9 @@ impl StreamContext {
|
|||
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverse(_)
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
|
||||
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_)
|
||||
| SupportedUpstreamAPIs::GeminiGenerateContent(_)
|
||||
| SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||
)
|
||||
| None => {
|
||||
// OpenAI and default: use Authorization Bearer token
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue