mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
add native Gemini provider support via hermesllm transforms
This commit is contained in:
parent
5400b0a2fa
commit
053108b96c
16 changed files with 2416 additions and 10 deletions
|
|
@ -53,7 +53,9 @@ pub async fn router_chat_get_upstream_model(
|
||||||
ProviderRequestType::MessagesRequest(_)
|
ProviderRequestType::MessagesRequest(_)
|
||||||
| ProviderRequestType::BedrockConverse(_)
|
| ProviderRequestType::BedrockConverse(_)
|
||||||
| ProviderRequestType::BedrockConverseStream(_)
|
| ProviderRequestType::BedrockConverseStream(_)
|
||||||
| ProviderRequestType::ResponsesAPIRequest(_),
|
| ProviderRequestType::ResponsesAPIRequest(_)
|
||||||
|
| ProviderRequestType::GeminiGenerateContent(_)
|
||||||
|
| ProviderRequestType::GeminiStreamGenerateContent(_),
|
||||||
) => {
|
) => {
|
||||||
warn!("unexpected: got non-ChatCompletions request after converting to OpenAI format");
|
warn!("unexpected: got non-ChatCompletions request after converting to OpenAI format");
|
||||||
return Err(RoutingError::internal_error(
|
return Err(RoutingError::internal_error(
|
||||||
|
|
|
||||||
744
crates/hermesllm/src/apis/gemini.rs
Normal file
744
crates/hermesllm/src/apis/gemini.rs
Normal file
|
|
@ -0,0 +1,744 @@
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use serde_with::skip_serializing_none;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use super::ApiDefinition;
|
||||||
|
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||||
|
use crate::providers::response::TokenUsage;
|
||||||
|
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||||
|
use crate::transforms::lib::ExtractText;
|
||||||
|
use crate::GENERATE_CONTENT_PATH_SUFFIX;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// GEMINI GENERATE CONTENT API ENUMERATION
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/// Enum for all supported Gemini GenerateContent APIs
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum GeminiApi {
|
||||||
|
GenerateContent,
|
||||||
|
StreamGenerateContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ApiDefinition for GeminiApi {
|
||||||
|
fn endpoint(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
GeminiApi::GenerateContent => ":generateContent",
|
||||||
|
GeminiApi::StreamGenerateContent => ":streamGenerateContent",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||||
|
if endpoint.ends_with(":streamGenerateContent") {
|
||||||
|
Some(GeminiApi::StreamGenerateContent)
|
||||||
|
} else if endpoint.ends_with(GENERATE_CONTENT_PATH_SUFFIX) {
|
||||||
|
Some(GeminiApi::GenerateContent)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_streaming(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
GeminiApi::GenerateContent => false,
|
||||||
|
GeminiApi::StreamGenerateContent => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_tools(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_vision(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn all_variants() -> Vec<Self> {
|
||||||
|
vec![GeminiApi::GenerateContent, GeminiApi::StreamGenerateContent]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// REQUEST TYPES
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct GenerateContentRequest {
|
||||||
|
/// Internal model field — not part of Gemini wire format (model is in the URL).
|
||||||
|
/// Populated during parsing and used for routing.
|
||||||
|
#[serde(skip_serializing, default)]
|
||||||
|
pub model: String,
|
||||||
|
|
||||||
|
pub contents: Vec<Content>,
|
||||||
|
pub generation_config: Option<GenerationConfig>,
|
||||||
|
pub tools: Option<Vec<Tool>>,
|
||||||
|
pub tool_config: Option<ToolConfig>,
|
||||||
|
pub safety_settings: Option<Vec<SafetySetting>>,
|
||||||
|
pub system_instruction: Option<Content>,
|
||||||
|
pub cached_content: Option<String>,
|
||||||
|
|
||||||
|
#[serde(skip_serializing)]
|
||||||
|
pub metadata: Option<HashMap<String, Value>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Content {
|
||||||
|
pub role: Option<String>,
|
||||||
|
pub parts: Vec<Part>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Part {
|
||||||
|
pub text: Option<String>,
|
||||||
|
pub inline_data: Option<InlineData>,
|
||||||
|
pub function_call: Option<FunctionCall>,
|
||||||
|
pub function_response: Option<FunctionResponse>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct InlineData {
|
||||||
|
pub mime_type: String,
|
||||||
|
pub data: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct FunctionCall {
|
||||||
|
pub name: String,
|
||||||
|
pub args: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct FunctionResponse {
|
||||||
|
pub name: String,
|
||||||
|
pub response: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct GenerationConfig {
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
pub top_k: Option<u32>,
|
||||||
|
pub max_output_tokens: Option<u32>,
|
||||||
|
pub stop_sequences: Option<Vec<String>>,
|
||||||
|
pub response_mime_type: Option<String>,
|
||||||
|
pub candidate_count: Option<u32>,
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Tool {
|
||||||
|
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||||
|
pub code_execution: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct FunctionDeclaration {
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub parameters: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct ToolConfig {
|
||||||
|
pub function_calling_config: FunctionCallingConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct FunctionCallingConfig {
|
||||||
|
pub mode: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct SafetySetting {
|
||||||
|
pub category: String,
|
||||||
|
pub threshold: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// RESPONSE TYPES
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct GenerateContentResponse {
|
||||||
|
pub candidates: Option<Vec<Candidate>>,
|
||||||
|
pub usage_metadata: Option<UsageMetadata>,
|
||||||
|
pub prompt_feedback: Option<PromptFeedback>,
|
||||||
|
pub model_version: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Candidate {
|
||||||
|
pub content: Option<Content>,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
pub safety_ratings: Option<Vec<SafetyRating>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct UsageMetadata {
|
||||||
|
pub prompt_token_count: Option<u32>,
|
||||||
|
pub candidates_token_count: Option<u32>,
|
||||||
|
pub total_token_count: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenUsage for UsageMetadata {
|
||||||
|
fn completion_tokens(&self) -> usize {
|
||||||
|
self.candidates_token_count.unwrap_or(0) as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt_tokens(&self) -> usize {
|
||||||
|
self.prompt_token_count.unwrap_or(0) as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
fn total_tokens(&self) -> usize {
|
||||||
|
self.total_token_count.unwrap_or(0) as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PromptFeedback {
|
||||||
|
pub block_reason: Option<String>,
|
||||||
|
pub safety_ratings: Option<Vec<SafetyRating>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[skip_serializing_none]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct SafetyRating {
|
||||||
|
pub category: String,
|
||||||
|
pub probability: String,
|
||||||
|
pub blocked: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// PROVIDER REQUEST TRAIT IMPLEMENTATION
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl ProviderRequest for GenerateContentRequest {
|
||||||
|
fn model(&self) -> &str {
|
||||||
|
&self.model
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_model(&mut self, model: String) {
|
||||||
|
self.model = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_streaming(&self) -> bool {
|
||||||
|
// Gemini uses URL-based streaming, not a field in the request body
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_messages_text(&self) -> String {
|
||||||
|
let mut parts_text = Vec::new();
|
||||||
|
for content in &self.contents {
|
||||||
|
for part in &content.parts {
|
||||||
|
if let Some(text) = &part.text {
|
||||||
|
parts_text.push(text.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(system) = &self.system_instruction {
|
||||||
|
for part in &system.parts {
|
||||||
|
if let Some(text) = &part.text {
|
||||||
|
parts_text.push(text.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parts_text.join(" ")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_recent_user_message(&self) -> Option<String> {
|
||||||
|
self.contents
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find(|c| c.role.as_deref() == Some("user"))
|
||||||
|
.and_then(|c| {
|
||||||
|
c.parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|p| p.text.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.first()
|
||||||
|
.cloned()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_tool_names(&self) -> Option<Vec<String>> {
|
||||||
|
self.tools.as_ref().map(|tools| {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.filter_map(|t| t.function_declarations.as_ref())
|
||||||
|
.flatten()
|
||||||
|
.map(|f| f.name.clone())
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||||
|
serde_json::to_vec(self).map_err(|e| ProviderRequestError {
|
||||||
|
message: format!("Failed to serialize GenerateContentRequest: {}", e),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||||
|
&self.metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||||
|
if let Some(ref mut metadata) = self.metadata {
|
||||||
|
metadata.remove(key).is_some()
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_temperature(&self) -> Option<f32> {
|
||||||
|
self.generation_config
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|gc| gc.temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
|
||||||
|
use crate::apis::openai::{Message, MessageContent, Role};
|
||||||
|
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
|
||||||
|
// Convert system instruction
|
||||||
|
if let Some(system) = &self.system_instruction {
|
||||||
|
let text = system
|
||||||
|
.parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|p| p.text.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("");
|
||||||
|
if !text.is_empty() {
|
||||||
|
messages.push(Message {
|
||||||
|
role: Role::System,
|
||||||
|
content: Some(MessageContent::Text(text)),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert contents
|
||||||
|
for content in &self.contents {
|
||||||
|
let role = match content.role.as_deref() {
|
||||||
|
Some("model") => Role::Assistant,
|
||||||
|
_ => Role::User,
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = content
|
||||||
|
.parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|p| p.text.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
messages.push(Message {
|
||||||
|
role,
|
||||||
|
content: Some(MessageContent::Text(text)),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
messages
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
|
||||||
|
use crate::apis::openai::Role;
|
||||||
|
|
||||||
|
self.contents.clear();
|
||||||
|
self.system_instruction = None;
|
||||||
|
|
||||||
|
for msg in messages {
|
||||||
|
let text = msg.content.extract_text();
|
||||||
|
match msg.role {
|
||||||
|
Role::System => {
|
||||||
|
self.system_instruction = Some(Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some(text),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Role::User => {
|
||||||
|
self.contents.push(Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some(text),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Role::Assistant => {
|
||||||
|
self.contents.push(Content {
|
||||||
|
role: Some("model".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some(text),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Role::Tool => {
|
||||||
|
self.contents.push(Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some(text),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// PROVIDER STREAM RESPONSE TRAIT IMPLEMENTATION
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl ProviderStreamResponse for GenerateContentResponse {
|
||||||
|
fn content_delta(&self) -> Option<&str> {
|
||||||
|
self.candidates
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|candidates| candidates.first())
|
||||||
|
.and_then(|candidate| candidate.content.as_ref())
|
||||||
|
.and_then(|content| content.parts.first())
|
||||||
|
.and_then(|part| part.text.as_deref())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_final(&self) -> bool {
|
||||||
|
self.candidates
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|candidates| candidates.first())
|
||||||
|
.and_then(|candidate| candidate.finish_reason.as_deref())
|
||||||
|
.map(|reason| reason == "STOP" || reason == "MAX_TOKENS" || reason == "SAFETY")
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn role(&self) -> Option<&str> {
|
||||||
|
self.candidates
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|candidates| candidates.first())
|
||||||
|
.and_then(|candidate| candidate.content.as_ref())
|
||||||
|
.and_then(|content| content.role.as_deref())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn event_type(&self) -> Option<&str> {
|
||||||
|
None // Gemini doesn't use SSE event types
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// SERDE PARSING
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl TryFrom<&[u8]> for GenerateContentRequest {
|
||||||
|
type Error = serde_json::Error;
|
||||||
|
|
||||||
|
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||||
|
serde_json::from_slice(bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<&[u8]> for GenerateContentResponse {
|
||||||
|
type Error = serde_json::Error;
|
||||||
|
|
||||||
|
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||||
|
serde_json::from_slice(bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// TESTS
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_gemini_api_from_endpoint() {
|
||||||
|
assert_eq!(
|
||||||
|
GeminiApi::from_endpoint("/v1beta/models/gemini-pro:generateContent"),
|
||||||
|
Some(GeminiApi::GenerateContent)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
GeminiApi::from_endpoint("/v1beta/models/gemini-pro:streamGenerateContent"),
|
||||||
|
Some(GeminiApi::StreamGenerateContent)
|
||||||
|
);
|
||||||
|
assert_eq!(GeminiApi::from_endpoint("/v1/chat/completions"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_content_request_serde() {
|
||||||
|
let json_str = json!({
|
||||||
|
"contents": [{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{"text": "Hello"}]
|
||||||
|
}],
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": 0.7,
|
||||||
|
"maxOutputTokens": 1024
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let req: GenerateContentRequest = serde_json::from_value(json_str).unwrap();
|
||||||
|
assert_eq!(req.contents.len(), 1);
|
||||||
|
assert_eq!(req.contents[0].role, Some("user".to_string()));
|
||||||
|
assert_eq!(
|
||||||
|
req.generation_config.as_ref().unwrap().temperature,
|
||||||
|
Some(0.7)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
req.generation_config.as_ref().unwrap().max_output_tokens,
|
||||||
|
Some(1024)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Roundtrip
|
||||||
|
let bytes = serde_json::to_vec(&req).unwrap();
|
||||||
|
let req2: GenerateContentRequest = serde_json::from_slice(&bytes).unwrap();
|
||||||
|
assert_eq!(req2.contents.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_content_response_serde() {
|
||||||
|
let json_str = json!({
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{"text": "Hello! How can I help?"}]
|
||||||
|
},
|
||||||
|
"finishReason": "STOP"
|
||||||
|
}],
|
||||||
|
"usageMetadata": {
|
||||||
|
"promptTokenCount": 5,
|
||||||
|
"candidatesTokenCount": 7,
|
||||||
|
"totalTokenCount": 12
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let resp: GenerateContentResponse = serde_json::from_value(json_str).unwrap();
|
||||||
|
assert!(resp.candidates.is_some());
|
||||||
|
let candidates = resp.candidates.as_ref().unwrap();
|
||||||
|
assert_eq!(candidates.len(), 1);
|
||||||
|
assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP"));
|
||||||
|
assert_eq!(
|
||||||
|
resp.usage_metadata.as_ref().unwrap().prompt_token_count,
|
||||||
|
Some(5)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_content_request_with_tools() {
|
||||||
|
let json_str = json!({
|
||||||
|
"contents": [{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{"text": "What's the weather?"}]
|
||||||
|
}],
|
||||||
|
"tools": [{
|
||||||
|
"functionDeclarations": [{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {"type": "string"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}],
|
||||||
|
"toolConfig": {
|
||||||
|
"functionCallingConfig": {
|
||||||
|
"mode": "AUTO"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let req: GenerateContentRequest = serde_json::from_value(json_str).unwrap();
|
||||||
|
assert!(req.tools.is_some());
|
||||||
|
let tools = req.tools.as_ref().unwrap();
|
||||||
|
assert_eq!(tools.len(), 1);
|
||||||
|
let decls = tools[0].function_declarations.as_ref().unwrap();
|
||||||
|
assert_eq!(decls[0].name, "get_weather");
|
||||||
|
assert_eq!(
|
||||||
|
req.tool_config
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.function_calling_config
|
||||||
|
.mode,
|
||||||
|
"AUTO"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_content_response_with_function_call() {
|
||||||
|
let json_str = json!({
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{
|
||||||
|
"functionCall": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"args": {"location": "NYC"}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"finishReason": "STOP"
|
||||||
|
}]
|
||||||
|
});
|
||||||
|
|
||||||
|
let resp: GenerateContentResponse = serde_json::from_value(json_str).unwrap();
|
||||||
|
let candidates = resp.candidates.as_ref().unwrap();
|
||||||
|
let parts = &candidates[0].content.as_ref().unwrap().parts;
|
||||||
|
assert!(parts[0].function_call.is_some());
|
||||||
|
assert_eq!(parts[0].function_call.as_ref().unwrap().name, "get_weather");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_response_content_delta() {
|
||||||
|
let resp = GenerateContentResponse {
|
||||||
|
candidates: Some(vec![Candidate {
|
||||||
|
content: Some(Content {
|
||||||
|
role: Some("model".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some("Hello".to_string()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
}),
|
||||||
|
finish_reason: None,
|
||||||
|
safety_ratings: None,
|
||||||
|
}]),
|
||||||
|
usage_metadata: None,
|
||||||
|
prompt_feedback: None,
|
||||||
|
model_version: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(resp.content_delta(), Some("Hello"));
|
||||||
|
assert!(!resp.is_final());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_response_is_final() {
|
||||||
|
let resp = GenerateContentResponse {
|
||||||
|
candidates: Some(vec![Candidate {
|
||||||
|
content: Some(Content {
|
||||||
|
role: Some("model".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some("Done".to_string()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
}),
|
||||||
|
finish_reason: Some("STOP".to_string()),
|
||||||
|
safety_ratings: None,
|
||||||
|
}]),
|
||||||
|
usage_metadata: None,
|
||||||
|
prompt_feedback: None,
|
||||||
|
model_version: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(resp.is_final());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_provider_request_extract_text() {
|
||||||
|
let req = GenerateContentRequest {
|
||||||
|
model: "gemini-pro".to_string(),
|
||||||
|
contents: vec![Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some("Hello world".to_string()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
system_instruction: Some(Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some("Be helpful".to_string()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
}),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = req.extract_messages_text();
|
||||||
|
assert!(text.contains("Hello world"));
|
||||||
|
assert!(text.contains("Be helpful"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_provider_request_get_tool_names() {
|
||||||
|
let req = GenerateContentRequest {
|
||||||
|
model: "gemini-pro".to_string(),
|
||||||
|
contents: vec![],
|
||||||
|
tools: Some(vec![Tool {
|
||||||
|
function_declarations: Some(vec![
|
||||||
|
FunctionDeclaration {
|
||||||
|
name: "func_a".to_string(),
|
||||||
|
description: None,
|
||||||
|
parameters: None,
|
||||||
|
},
|
||||||
|
FunctionDeclaration {
|
||||||
|
name: "func_b".to_string(),
|
||||||
|
description: None,
|
||||||
|
parameters: None,
|
||||||
|
},
|
||||||
|
]),
|
||||||
|
code_execution: None,
|
||||||
|
}]),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let names = req.get_tool_names().unwrap();
|
||||||
|
assert_eq!(names, vec!["func_a", "func_b"]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod amazon_bedrock;
|
pub mod amazon_bedrock;
|
||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
|
pub mod gemini;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
pub mod openai_responses;
|
pub mod openai_responses;
|
||||||
pub mod streaming_shapes;
|
pub mod streaming_shapes;
|
||||||
|
|
@ -10,6 +11,7 @@ pub use amazon_bedrock::{
|
||||||
Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
|
Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
|
||||||
};
|
};
|
||||||
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
|
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
|
||||||
|
pub use gemini::{GeminiApi, GenerateContentRequest, GenerateContentResponse};
|
||||||
pub use openai::{
|
pub use openai::{
|
||||||
ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
|
ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
|
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, GeminiApi, OpenAIApi};
|
||||||
use crate::ProviderId;
|
use crate::ProviderId;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
|
@ -8,6 +8,7 @@ pub enum SupportedAPIsFromClient {
|
||||||
OpenAIChatCompletions(OpenAIApi),
|
OpenAIChatCompletions(OpenAIApi),
|
||||||
AnthropicMessagesAPI(AnthropicApi),
|
AnthropicMessagesAPI(AnthropicApi),
|
||||||
OpenAIResponsesAPI(OpenAIApi),
|
OpenAIResponsesAPI(OpenAIApi),
|
||||||
|
GeminiGenerateContentAPI(GeminiApi),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
|
@ -17,6 +18,8 @@ pub enum SupportedUpstreamAPIs {
|
||||||
AmazonBedrockConverse(AmazonBedrockApi),
|
AmazonBedrockConverse(AmazonBedrockApi),
|
||||||
AmazonBedrockConverseStream(AmazonBedrockApi),
|
AmazonBedrockConverseStream(AmazonBedrockApi),
|
||||||
OpenAIResponsesAPI(OpenAIApi),
|
OpenAIResponsesAPI(OpenAIApi),
|
||||||
|
GeminiGenerateContent(GeminiApi),
|
||||||
|
GeminiStreamGenerateContent(GeminiApi),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for SupportedAPIsFromClient {
|
impl fmt::Display for SupportedAPIsFromClient {
|
||||||
|
|
@ -31,6 +34,9 @@ impl fmt::Display for SupportedAPIsFromClient {
|
||||||
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => {
|
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => {
|
||||||
write!(f, "OpenAI Responses ({})", api.endpoint())
|
write!(f, "OpenAI Responses ({})", api.endpoint())
|
||||||
}
|
}
|
||||||
|
SupportedAPIsFromClient::GeminiGenerateContentAPI(api) => {
|
||||||
|
write!(f, "Gemini ({})", api.endpoint())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -53,6 +59,12 @@ impl fmt::Display for SupportedUpstreamAPIs {
|
||||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(api) => {
|
SupportedUpstreamAPIs::OpenAIResponsesAPI(api) => {
|
||||||
write!(f, "OpenAI Responses ({})", api.endpoint())
|
write!(f, "OpenAI Responses ({})", api.endpoint())
|
||||||
}
|
}
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(api) => {
|
||||||
|
write!(f, "Gemini ({})", api.endpoint())
|
||||||
|
}
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(api) => {
|
||||||
|
write!(f, "Gemini Stream ({})", api.endpoint())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -60,6 +72,13 @@ impl fmt::Display for SupportedUpstreamAPIs {
|
||||||
impl SupportedAPIsFromClient {
|
impl SupportedAPIsFromClient {
|
||||||
/// Create a SupportedApi from an endpoint path
|
/// Create a SupportedApi from an endpoint path
|
||||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||||
|
// Check Gemini first since it uses suffix matching (`:generateContent`)
|
||||||
|
if let Some(gemini_api) = GeminiApi::from_endpoint(endpoint) {
|
||||||
|
return Some(SupportedAPIsFromClient::GeminiGenerateContentAPI(
|
||||||
|
gemini_api,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||||
// Check if this is the Responses API endpoint
|
// Check if this is the Responses API endpoint
|
||||||
if openai_api == OpenAIApi::Responses {
|
if openai_api == OpenAIApi::Responses {
|
||||||
|
|
@ -82,6 +101,7 @@ impl SupportedAPIsFromClient {
|
||||||
SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(),
|
SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(),
|
||||||
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(),
|
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(),
|
||||||
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(),
|
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(),
|
||||||
|
SupportedAPIsFromClient::GeminiGenerateContentAPI(api) => api.endpoint(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -145,7 +165,18 @@ impl SupportedAPIsFromClient {
|
||||||
}
|
}
|
||||||
ProviderId::Gemini => {
|
ProviderId::Gemini => {
|
||||||
if request_path.starts_with("/v1/") {
|
if request_path.starts_with("/v1/") {
|
||||||
build_endpoint("/v1beta/openai", endpoint_suffix)
|
// Use native Gemini endpoint
|
||||||
|
if !is_streaming {
|
||||||
|
build_endpoint(
|
||||||
|
"/v1beta",
|
||||||
|
&format!("/models/{}:generateContent", model_id),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
build_endpoint(
|
||||||
|
"/v1beta",
|
||||||
|
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
|
||||||
|
)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
build_endpoint("/v1", endpoint_suffix)
|
build_endpoint("/v1", endpoint_suffix)
|
||||||
}
|
}
|
||||||
|
|
@ -178,6 +209,20 @@ impl SupportedAPIsFromClient {
|
||||||
build_endpoint("/v1", "/chat/completions")
|
build_endpoint("/v1", "/chat/completions")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ProviderId::Gemini => {
|
||||||
|
// Translate Anthropic → Gemini native
|
||||||
|
if !is_streaming {
|
||||||
|
build_endpoint(
|
||||||
|
"/v1beta",
|
||||||
|
&format!("/models/{}:generateContent", model_id),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
build_endpoint(
|
||||||
|
"/v1beta",
|
||||||
|
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => build_endpoint("/v1", "/chat/completions"),
|
_ => build_endpoint("/v1", "/chat/completions"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -186,6 +231,20 @@ impl SupportedAPIsFromClient {
|
||||||
match provider_id {
|
match provider_id {
|
||||||
// Providers that support /v1/responses natively
|
// Providers that support /v1/responses natively
|
||||||
ProviderId::OpenAI | ProviderId::XAI => route_by_provider("/responses"),
|
ProviderId::OpenAI | ProviderId::XAI => route_by_provider("/responses"),
|
||||||
|
ProviderId::Gemini => {
|
||||||
|
// Translate Responses → Gemini native
|
||||||
|
if !is_streaming {
|
||||||
|
build_endpoint(
|
||||||
|
"/v1beta",
|
||||||
|
&format!("/models/{}:generateContent", model_id),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
build_endpoint(
|
||||||
|
"/v1beta",
|
||||||
|
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
// All other providers: translate to /chat/completions
|
// All other providers: translate to /chat/completions
|
||||||
_ => route_by_provider("/chat/completions"),
|
_ => route_by_provider("/chat/completions"),
|
||||||
}
|
}
|
||||||
|
|
@ -194,6 +253,33 @@ impl SupportedAPIsFromClient {
|
||||||
// For Chat Completions API, use the standard chat/completions path
|
// For Chat Completions API, use the standard chat/completions path
|
||||||
route_by_provider("/chat/completions")
|
route_by_provider("/chat/completions")
|
||||||
}
|
}
|
||||||
|
SupportedAPIsFromClient::GeminiGenerateContentAPI(_) => {
|
||||||
|
match provider_id {
|
||||||
|
ProviderId::Gemini => {
|
||||||
|
// Native Gemini endpoint
|
||||||
|
if !is_streaming {
|
||||||
|
build_endpoint(
|
||||||
|
"/v1beta",
|
||||||
|
&format!("/models/{}:generateContent", model_id),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
build_endpoint(
|
||||||
|
"/v1beta",
|
||||||
|
&format!("/models/{}:streamGenerateContent?alt=sse", model_id),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
|
||||||
|
ProviderId::AmazonBedrock => {
|
||||||
|
if !is_streaming {
|
||||||
|
build_endpoint("", &format!("/model/{}/converse", model_id))
|
||||||
|
} else {
|
||||||
|
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => build_endpoint("/v1", "/chat/completions"),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -201,6 +287,18 @@ impl SupportedAPIsFromClient {
|
||||||
impl SupportedUpstreamAPIs {
|
impl SupportedUpstreamAPIs {
|
||||||
/// Create a SupportedUpstreamApi from an endpoint path
|
/// Create a SupportedUpstreamApi from an endpoint path
|
||||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||||
|
// Check Gemini first since it uses suffix matching
|
||||||
|
if let Some(gemini_api) = GeminiApi::from_endpoint(endpoint) {
|
||||||
|
return match gemini_api {
|
||||||
|
GeminiApi::GenerateContent => {
|
||||||
|
Some(SupportedUpstreamAPIs::GeminiGenerateContent(gemini_api))
|
||||||
|
}
|
||||||
|
GeminiApi::StreamGenerateContent => Some(
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(gemini_api),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||||
// Check if this is the Responses API endpoint
|
// Check if this is the Responses API endpoint
|
||||||
if openai_api == OpenAIApi::Responses {
|
if openai_api == OpenAIApi::Responses {
|
||||||
|
|
@ -396,7 +494,7 @@ mod tests {
|
||||||
"/openai/deployments/gpt-4/chat/completions?api-version=2025-01-01-preview"
|
"/openai/deployments/gpt-4/chat/completions?api-version=2025-01-01-preview"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test Gemini provider
|
// Test Gemini provider (uses native Gemini API with transforms)
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
api.target_endpoint_for_provider(
|
api.target_endpoint_for_provider(
|
||||||
&ProviderId::Gemini,
|
&ProviderId::Gemini,
|
||||||
|
|
@ -405,7 +503,7 @@ mod tests {
|
||||||
false,
|
false,
|
||||||
None
|
None
|
||||||
),
|
),
|
||||||
"/v1beta/openai/chat/completions"
|
"/v1beta/models/gemini-pro:generateContent"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ pub use providers::streaming_response::{ProviderStreamResponse, ProviderStreamRe
|
||||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||||
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
|
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
|
||||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||||
|
pub const GENERATE_CONTENT_PATH_SUFFIX: &str = ":generateContent";
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
|
use crate::apis::{AmazonBedrockApi, AnthropicApi, GeminiApi, OpenAIApi};
|
||||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -116,7 +116,68 @@ impl ProviderId {
|
||||||
is_streaming: bool,
|
is_streaming: bool,
|
||||||
) -> SupportedUpstreamAPIs {
|
) -> SupportedUpstreamAPIs {
|
||||||
match (self, client_api) {
|
match (self, client_api) {
|
||||||
|
// ============================================================================
|
||||||
|
// Gemini provider — use native Gemini APIs
|
||||||
|
// ============================================================================
|
||||||
|
(ProviderId::Gemini, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
|
||||||
|
if is_streaming {
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
|
||||||
|
GeminiApi::StreamGenerateContent,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(ProviderId::Gemini, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
|
||||||
|
if is_streaming {
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
|
||||||
|
GeminiApi::StreamGenerateContent,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(ProviderId::Gemini, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
||||||
|
if is_streaming {
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
|
||||||
|
GeminiApi::StreamGenerateContent,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(ProviderId::Gemini, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||||
|
if is_streaming {
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(
|
||||||
|
GeminiApi::StreamGenerateContent,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Non-Gemini providers receiving Gemini-format requests
|
||||||
|
// ============================================================================
|
||||||
|
(ProviderId::Anthropic, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
|
||||||
|
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||||
|
}
|
||||||
|
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
|
||||||
|
if is_streaming {
|
||||||
|
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||||
|
AmazonBedrockApi::ConverseStream,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(_, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => {
|
||||||
|
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
// Claude/Anthropic providers natively support Anthropic APIs
|
// Claude/Anthropic providers natively support Anthropic APIs
|
||||||
|
// ============================================================================
|
||||||
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
||||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||||
}
|
}
|
||||||
|
|
@ -136,7 +197,6 @@ impl ProviderId {
|
||||||
| ProviderId::Mistral
|
| ProviderId::Mistral
|
||||||
| ProviderId::Deepseek
|
| ProviderId::Deepseek
|
||||||
| ProviderId::Arch
|
| ProviderId::Arch
|
||||||
| ProviderId::Gemini
|
|
||||||
| ProviderId::GitHub
|
| ProviderId::GitHub
|
||||||
| ProviderId::AzureOpenAI
|
| ProviderId::AzureOpenAI
|
||||||
| ProviderId::XAI
|
| ProviderId::XAI
|
||||||
|
|
@ -154,7 +214,6 @@ impl ProviderId {
|
||||||
| ProviderId::Mistral
|
| ProviderId::Mistral
|
||||||
| ProviderId::Deepseek
|
| ProviderId::Deepseek
|
||||||
| ProviderId::Arch
|
| ProviderId::Arch
|
||||||
| ProviderId::Gemini
|
|
||||||
| ProviderId::GitHub
|
| ProviderId::GitHub
|
||||||
| ProviderId::AzureOpenAI
|
| ProviderId::AzureOpenAI
|
||||||
| ProviderId::XAI
|
| ProviderId::XAI
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
use crate::apis::anthropic::MessagesRequest;
|
use crate::apis::anthropic::MessagesRequest;
|
||||||
|
use crate::apis::gemini::GenerateContentRequest;
|
||||||
use crate::apis::openai::ChatCompletionsRequest;
|
use crate::apis::openai::ChatCompletionsRequest;
|
||||||
|
use crate::apis::ApiDefinition;
|
||||||
|
|
||||||
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
|
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
|
||||||
use crate::apis::openai_responses::ResponsesAPIRequest;
|
use crate::apis::openai_responses::ResponsesAPIRequest;
|
||||||
|
|
@ -19,7 +21,8 @@ pub enum ProviderRequestType {
|
||||||
BedrockConverse(ConverseRequest),
|
BedrockConverse(ConverseRequest),
|
||||||
BedrockConverseStream(ConverseStreamRequest),
|
BedrockConverseStream(ConverseStreamRequest),
|
||||||
ResponsesAPIRequest(ResponsesAPIRequest),
|
ResponsesAPIRequest(ResponsesAPIRequest),
|
||||||
//add more request types here
|
GeminiGenerateContent(GenerateContentRequest),
|
||||||
|
GeminiStreamGenerateContent(GenerateContentRequest),
|
||||||
}
|
}
|
||||||
pub trait ProviderRequest: Send + Sync {
|
pub trait ProviderRequest: Send + Sync {
|
||||||
/// Extract the model name from the request
|
/// Extract the model name from the request
|
||||||
|
|
@ -69,6 +72,9 @@ impl ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.set_messages(messages),
|
Self::BedrockConverse(r) => r.set_messages(messages),
|
||||||
Self::BedrockConverseStream(r) => r.set_messages(messages),
|
Self::BedrockConverseStream(r) => r.set_messages(messages),
|
||||||
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
|
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||||
|
r.set_messages(messages)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -100,6 +106,7 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.model(),
|
Self::BedrockConverse(r) => r.model(),
|
||||||
Self::BedrockConverseStream(r) => r.model(),
|
Self::BedrockConverseStream(r) => r.model(),
|
||||||
Self::ResponsesAPIRequest(r) => r.model(),
|
Self::ResponsesAPIRequest(r) => r.model(),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.model(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,6 +117,9 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.set_model(model),
|
Self::BedrockConverse(r) => r.set_model(model),
|
||||||
Self::BedrockConverseStream(r) => r.set_model(model),
|
Self::BedrockConverseStream(r) => r.set_model(model),
|
||||||
Self::ResponsesAPIRequest(r) => r.set_model(model),
|
Self::ResponsesAPIRequest(r) => r.set_model(model),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||||
|
r.set_model(model)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -120,6 +130,8 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(_) => false,
|
Self::BedrockConverse(_) => false,
|
||||||
Self::BedrockConverseStream(_) => true,
|
Self::BedrockConverseStream(_) => true,
|
||||||
Self::ResponsesAPIRequest(r) => r.is_streaming(),
|
Self::ResponsesAPIRequest(r) => r.is_streaming(),
|
||||||
|
Self::GeminiGenerateContent(_) => false,
|
||||||
|
Self::GeminiStreamGenerateContent(_) => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -130,6 +142,9 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.extract_messages_text(),
|
Self::BedrockConverse(r) => r.extract_messages_text(),
|
||||||
Self::BedrockConverseStream(r) => r.extract_messages_text(),
|
Self::BedrockConverseStream(r) => r.extract_messages_text(),
|
||||||
Self::ResponsesAPIRequest(r) => r.extract_messages_text(),
|
Self::ResponsesAPIRequest(r) => r.extract_messages_text(),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||||
|
r.extract_messages_text()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -140,6 +155,9 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.get_recent_user_message(),
|
Self::BedrockConverse(r) => r.get_recent_user_message(),
|
||||||
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
|
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
|
||||||
Self::ResponsesAPIRequest(r) => r.get_recent_user_message(),
|
Self::ResponsesAPIRequest(r) => r.get_recent_user_message(),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||||
|
r.get_recent_user_message()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -150,6 +168,9 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.get_tool_names(),
|
Self::BedrockConverse(r) => r.get_tool_names(),
|
||||||
Self::BedrockConverseStream(r) => r.get_tool_names(),
|
Self::BedrockConverseStream(r) => r.get_tool_names(),
|
||||||
Self::ResponsesAPIRequest(r) => r.get_tool_names(),
|
Self::ResponsesAPIRequest(r) => r.get_tool_names(),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||||
|
r.get_tool_names()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -160,6 +181,7 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.to_bytes(),
|
Self::BedrockConverse(r) => r.to_bytes(),
|
||||||
Self::BedrockConverseStream(r) => r.to_bytes(),
|
Self::BedrockConverseStream(r) => r.to_bytes(),
|
||||||
Self::ResponsesAPIRequest(r) => r.to_bytes(),
|
Self::ResponsesAPIRequest(r) => r.to_bytes(),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.to_bytes(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -170,6 +192,7 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.metadata(),
|
Self::BedrockConverse(r) => r.metadata(),
|
||||||
Self::BedrockConverseStream(r) => r.metadata(),
|
Self::BedrockConverseStream(r) => r.metadata(),
|
||||||
Self::ResponsesAPIRequest(r) => r.metadata(),
|
Self::ResponsesAPIRequest(r) => r.metadata(),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.metadata(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -180,6 +203,9 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.remove_metadata_key(key),
|
Self::BedrockConverse(r) => r.remove_metadata_key(key),
|
||||||
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
|
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
|
||||||
Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key),
|
Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||||
|
r.remove_metadata_key(key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -190,6 +216,9 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.get_temperature(),
|
Self::BedrockConverse(r) => r.get_temperature(),
|
||||||
Self::BedrockConverseStream(r) => r.get_temperature(),
|
Self::BedrockConverseStream(r) => r.get_temperature(),
|
||||||
Self::ResponsesAPIRequest(r) => r.get_temperature(),
|
Self::ResponsesAPIRequest(r) => r.get_temperature(),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||||
|
r.get_temperature()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -200,6 +229,9 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.get_messages(),
|
Self::BedrockConverse(r) => r.get_messages(),
|
||||||
Self::BedrockConverseStream(r) => r.get_messages(),
|
Self::BedrockConverseStream(r) => r.get_messages(),
|
||||||
Self::ResponsesAPIRequest(r) => r.get_messages(),
|
Self::ResponsesAPIRequest(r) => r.get_messages(),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||||
|
r.get_messages()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -210,6 +242,9 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::BedrockConverse(r) => r.set_messages(messages),
|
Self::BedrockConverse(r) => r.set_messages(messages),
|
||||||
Self::BedrockConverseStream(r) => r.set_messages(messages),
|
Self::BedrockConverseStream(r) => r.set_messages(messages),
|
||||||
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
|
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
|
||||||
|
Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => {
|
||||||
|
r.set_messages(messages)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -245,6 +280,18 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType {
|
||||||
responses_apirequest,
|
responses_apirequest,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
SupportedAPIsFromClient::GeminiGenerateContentAPI(gemini_api) => {
|
||||||
|
let gemini_request: GenerateContentRequest =
|
||||||
|
GenerateContentRequest::try_from(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
if gemini_api.supports_streaming() {
|
||||||
|
Ok(ProviderRequestType::GeminiStreamGenerateContent(
|
||||||
|
gemini_request,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(ProviderRequestType::GeminiGenerateContent(gemini_request))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -309,6 +356,37 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
||||||
source: None,
|
source: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
// ChatCompletions -> Gemini
|
||||||
|
(
|
||||||
|
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||||
|
) => {
|
||||||
|
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert ChatCompletionsRequest to GenerateContentRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::GeminiGenerateContent(gemini_req))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||||
|
) => {
|
||||||
|
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert ChatCompletionsRequest to GenerateContentRequest (stream): {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req))
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// MessagesRequest conversions
|
// MessagesRequest conversions
|
||||||
|
|
@ -370,6 +448,37 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
||||||
source: None,
|
source: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
// Messages -> Gemini (chain: Anthropic -> OpenAI -> Gemini)
|
||||||
|
(
|
||||||
|
ProviderRequestType::MessagesRequest(messages_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||||
|
) => {
|
||||||
|
let gemini_req = GenerateContentRequest::try_from(messages_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert MessagesRequest to GenerateContentRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::GeminiGenerateContent(gemini_req))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ProviderRequestType::MessagesRequest(messages_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||||
|
) => {
|
||||||
|
let gemini_req = GenerateContentRequest::try_from(messages_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert MessagesRequest to GenerateContentRequest (stream): {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req))
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// ResponsesAPIRequest conversions (only converts TO other formats)
|
// ResponsesAPIRequest conversions (only converts TO other formats)
|
||||||
|
|
@ -480,6 +589,171 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
||||||
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResponsesAPI -> Gemini (via ChatCompletions)
|
||||||
|
(
|
||||||
|
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||||
|
) => {
|
||||||
|
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert ChatCompletionsRequest to GenerateContentRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::GeminiGenerateContent(gemini_req))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||||
|
) => {
|
||||||
|
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert ChatCompletionsRequest to GenerateContentRequest (stream): {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// GeminiGenerateContent conversions (client sends Gemini format)
|
||||||
|
// ============================================================================
|
||||||
|
(
|
||||||
|
ProviderRequestType::GeminiGenerateContent(gemini_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||||
|
) => Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)),
|
||||||
|
(
|
||||||
|
ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||||
|
) => Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)),
|
||||||
|
// Cross-streaming mode: non-streaming -> streaming and vice versa
|
||||||
|
(
|
||||||
|
ProviderRequestType::GeminiGenerateContent(gemini_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||||
|
) => Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)),
|
||||||
|
(
|
||||||
|
ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||||
|
) => Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)),
|
||||||
|
(
|
||||||
|
ProviderRequestType::GeminiGenerateContent(gemini_req)
|
||||||
|
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||||
|
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||||
|
) => {
|
||||||
|
let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ProviderRequestType::GeminiGenerateContent(gemini_req)
|
||||||
|
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||||
|
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||||
|
) => {
|
||||||
|
let messages_req = MessagesRequest::try_from(gemini_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert GenerateContentRequest to MessagesRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ProviderRequestType::GeminiGenerateContent(gemini_req)
|
||||||
|
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||||
|
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||||
|
) => {
|
||||||
|
// Chain: Gemini -> OpenAI -> Bedrock
|
||||||
|
let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
let bedrock_req = ConverseRequest::try_from(chat_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert ChatCompletionsRequest to ConverseRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ProviderRequestType::GeminiGenerateContent(gemini_req)
|
||||||
|
| ProviderRequestType::GeminiStreamGenerateContent(gemini_req),
|
||||||
|
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||||
|
) => {
|
||||||
|
// Chain: Gemini -> OpenAI -> Bedrock Stream
|
||||||
|
let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
let bedrock_req = ConverseStreamRequest::try_from(chat_req).map_err(|e| {
|
||||||
|
ProviderRequestError {
|
||||||
|
message: format!(
|
||||||
|
"Failed to convert ChatCompletionsRequest to ConverseStreamRequest: {}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ProviderRequestType::GeminiGenerateContent(_)
|
||||||
|
| ProviderRequestType::GeminiStreamGenerateContent(_),
|
||||||
|
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||||
|
) => {
|
||||||
|
Err(ProviderRequestError {
|
||||||
|
message: "Conversion from GenerateContentRequest to ResponsesAPIRequest is not supported.".to_string(),
|
||||||
|
source: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// Amazon Bedrock conversions (not supported as client API)
|
// Amazon Bedrock conversions (not supported as client API)
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
use crate::apis::amazon_bedrock::ConverseResponse;
|
use crate::apis::amazon_bedrock::ConverseResponse;
|
||||||
use crate::apis::anthropic::MessagesResponse;
|
use crate::apis::anthropic::MessagesResponse;
|
||||||
|
use crate::apis::gemini::GenerateContentResponse;
|
||||||
use crate::apis::openai::ChatCompletionsResponse;
|
use crate::apis::openai::ChatCompletionsResponse;
|
||||||
use crate::apis::openai_responses::ResponsesAPIResponse;
|
use crate::apis::openai_responses::ResponsesAPIResponse;
|
||||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||||
|
|
@ -16,6 +17,7 @@ pub enum ProviderResponseType {
|
||||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||||
MessagesResponse(MessagesResponse),
|
MessagesResponse(MessagesResponse),
|
||||||
ResponsesAPIResponse(Box<ResponsesAPIResponse>),
|
ResponsesAPIResponse(Box<ResponsesAPIResponse>),
|
||||||
|
GenerateContentResponse(GenerateContentResponse),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for token usage information
|
/// Trait for token usage information
|
||||||
|
|
@ -44,6 +46,9 @@ impl ProviderResponse for ProviderResponseType {
|
||||||
ProviderResponseType::ResponsesAPIResponse(resp) => {
|
ProviderResponseType::ResponsesAPIResponse(resp) => {
|
||||||
resp.usage.as_ref().map(|u| u as &dyn TokenUsage)
|
resp.usage.as_ref().map(|u| u as &dyn TokenUsage)
|
||||||
}
|
}
|
||||||
|
ProviderResponseType::GenerateContentResponse(resp) => {
|
||||||
|
resp.usage_metadata.as_ref().map(|u| u as &dyn TokenUsage)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -58,6 +63,15 @@ impl ProviderResponse for ProviderResponseType {
|
||||||
u.total_tokens as usize,
|
u.total_tokens as usize,
|
||||||
)
|
)
|
||||||
}),
|
}),
|
||||||
|
ProviderResponseType::GenerateContentResponse(resp) => {
|
||||||
|
resp.usage_metadata.as_ref().map(|u| {
|
||||||
|
(
|
||||||
|
u.prompt_token_count.unwrap_or(0) as usize,
|
||||||
|
u.candidates_token_count.unwrap_or(0) as usize,
|
||||||
|
u.total_token_count.unwrap_or(0) as usize,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -238,6 +252,140 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
|
||||||
response_api,
|
response_api,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
// ============================================================================
|
||||||
|
// Gemini upstream transformations
|
||||||
|
// ============================================================================
|
||||||
|
(
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||||
|
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
|
||||||
|
) => {
|
||||||
|
// Passthrough: Gemini upstream -> Gemini client
|
||||||
|
let resp: GenerateContentResponse = serde_json::from_slice(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
Ok(ProviderResponseType::GenerateContentResponse(resp))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||||
|
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||||
|
) => {
|
||||||
|
// Gemini upstream -> OpenAI client
|
||||||
|
let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||||
|
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
|
||||||
|
) => {
|
||||||
|
// Chain: Gemini -> OpenAI -> Anthropic
|
||||||
|
let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let messages_resp: MessagesResponse = chat_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(ProviderResponseType::MessagesResponse(messages_resp))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
SupportedUpstreamAPIs::GeminiGenerateContent(_),
|
||||||
|
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
|
||||||
|
) => {
|
||||||
|
// Chain: Gemini -> OpenAI -> ResponsesAPI
|
||||||
|
let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let responses_resp: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
|
||||||
|
responses_resp,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Non-Gemini upstream -> Gemini client
|
||||||
|
// ============================================================================
|
||||||
|
(
|
||||||
|
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||||
|
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
|
||||||
|
) => {
|
||||||
|
// OpenAI upstream -> Gemini client
|
||||||
|
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
let gemini_resp: GenerateContentResponse = openai_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||||
|
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
|
||||||
|
) => {
|
||||||
|
// Chain: Anthropic -> OpenAI -> Gemini
|
||||||
|
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
let chat_resp: ChatCompletionsResponse =
|
||||||
|
anthropic_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let gemini_resp: GenerateContentResponse = chat_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||||
|
SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
|
||||||
|
) => {
|
||||||
|
// Chain: Bedrock -> OpenAI -> Gemini
|
||||||
|
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let gemini_resp: GenerateContentResponse = chat_resp.try_into().map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("Transformation error: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
|
||||||
|
}
|
||||||
|
|
||||||
_ => Err(std::io::Error::new(
|
_ => Err(std::io::Error::new(
|
||||||
std::io::ErrorKind::InvalidData,
|
std::io::ErrorKind::InvalidData,
|
||||||
"Unsupported API combination for response transformation",
|
"Unsupported API combination for response transformation",
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,11 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseStreamBu
|
||||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
|
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
|
||||||
Ok(SseStreamBuffer::OpenAIResponses(Box::default()))
|
Ok(SseStreamBuffer::OpenAIResponses(Box::default()))
|
||||||
}
|
}
|
||||||
|
SupportedAPIsFromClient::GeminiGenerateContentAPI(_) => {
|
||||||
|
// Gemini client with a different upstream - use passthrough
|
||||||
|
// since Gemini streaming uses SSE and doesn't need special buffering
|
||||||
|
Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ pub mod response_streaming;
|
||||||
|
|
||||||
// Re-export commonly used items for convenience
|
// Re-export commonly used items for convenience
|
||||||
pub use lib::*;
|
pub use lib::*;
|
||||||
|
#[allow(ambiguous_glob_reexports)]
|
||||||
pub use request::*;
|
pub use request::*;
|
||||||
pub use response::*;
|
pub use response::*;
|
||||||
pub use response_streaming::*;
|
pub use response_streaming::*;
|
||||||
|
|
|
||||||
327
crates/hermesllm/src/transforms/request/from_gemini.rs
Normal file
327
crates/hermesllm/src/transforms/request/from_gemini.rs
Normal file
|
|
@ -0,0 +1,327 @@
|
||||||
|
use crate::apis::gemini::GenerateContentRequest;
|
||||||
|
use crate::apis::openai::{
|
||||||
|
ChatCompletionsRequest, Function, FunctionCall as OpenAIFunctionCall, Message, MessageContent,
|
||||||
|
Role, Tool, ToolCall as OpenAIToolCall, ToolChoice, ToolChoiceType,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::apis::anthropic::MessagesRequest;
|
||||||
|
use crate::clients::TransformError;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Gemini GenerateContent -> OpenAI ChatCompletions
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl TryFrom<GenerateContentRequest> for ChatCompletionsRequest {
|
||||||
|
type Error = TransformError;
|
||||||
|
|
||||||
|
fn try_from(req: GenerateContentRequest) -> Result<Self, Self::Error> {
|
||||||
|
let mut messages: Vec<Message> = Vec::new();
|
||||||
|
|
||||||
|
// Convert system instruction
|
||||||
|
if let Some(system) = &req.system_instruction {
|
||||||
|
let text = system
|
||||||
|
.parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|p| p.text.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("");
|
||||||
|
if !text.is_empty() {
|
||||||
|
messages.push(Message {
|
||||||
|
role: Role::System,
|
||||||
|
content: Some(MessageContent::Text(text)),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert contents
|
||||||
|
for content in &req.contents {
|
||||||
|
let role = match content.role.as_deref() {
|
||||||
|
Some("model") => Role::Assistant,
|
||||||
|
_ => Role::User,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if this content has function_call parts (assistant with tool calls)
|
||||||
|
let has_function_calls = content.parts.iter().any(|p| p.function_call.is_some());
|
||||||
|
let has_function_responses =
|
||||||
|
content.parts.iter().any(|p| p.function_response.is_some());
|
||||||
|
|
||||||
|
if has_function_calls {
|
||||||
|
// Convert to assistant message with tool_calls
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
let mut text_parts = Vec::new();
|
||||||
|
|
||||||
|
for (i, part) in content.parts.iter().enumerate() {
|
||||||
|
if let Some(fc) = &part.function_call {
|
||||||
|
tool_calls.push(OpenAIToolCall {
|
||||||
|
id: format!("call_{}", i),
|
||||||
|
call_type: "function".to_string(),
|
||||||
|
function: OpenAIFunctionCall {
|
||||||
|
name: fc.name.clone(),
|
||||||
|
arguments: serde_json::to_string(&fc.args).unwrap_or_default(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else if let Some(text) = &part.text {
|
||||||
|
text_parts.push(text.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let content_text = if text_parts.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(MessageContent::Text(text_parts.join("")))
|
||||||
|
};
|
||||||
|
|
||||||
|
messages.push(Message {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: content_text,
|
||||||
|
name: None,
|
||||||
|
tool_calls: if tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(tool_calls)
|
||||||
|
},
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
} else if has_function_responses {
|
||||||
|
// Convert each function_response to a tool message
|
||||||
|
for part in &content.parts {
|
||||||
|
if let Some(fr) = &part.function_response {
|
||||||
|
let result_text = serde_json::to_string(&fr.response).unwrap_or_default();
|
||||||
|
messages.push(Message {
|
||||||
|
role: Role::Tool,
|
||||||
|
content: Some(MessageContent::Text(result_text)),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: Some(fr.name.clone()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Regular text message
|
||||||
|
let text = content
|
||||||
|
.parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|p| p.text.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("");
|
||||||
|
messages.push(Message {
|
||||||
|
role,
|
||||||
|
content: Some(MessageContent::Text(text)),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert generation config
|
||||||
|
let (temperature, top_p, max_tokens, stop, presence_penalty, frequency_penalty) =
|
||||||
|
if let Some(gc) = &req.generation_config {
|
||||||
|
(
|
||||||
|
gc.temperature,
|
||||||
|
gc.top_p,
|
||||||
|
gc.max_output_tokens,
|
||||||
|
gc.stop_sequences.clone(),
|
||||||
|
gc.presence_penalty,
|
||||||
|
gc.frequency_penalty,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(None, None, None, None, None, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert tools
|
||||||
|
let tools = req.tools.and_then(|gemini_tools| {
|
||||||
|
let openai_tools: Vec<Tool> = gemini_tools
|
||||||
|
.iter()
|
||||||
|
.filter_map(|t| t.function_declarations.as_ref())
|
||||||
|
.flatten()
|
||||||
|
.map(|fd| Tool {
|
||||||
|
tool_type: "function".to_string(),
|
||||||
|
function: Function {
|
||||||
|
name: fd.name.clone(),
|
||||||
|
description: fd.description.clone(),
|
||||||
|
parameters: fd.parameters.clone().unwrap_or_default(),
|
||||||
|
strict: None,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
if openai_tools.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(openai_tools)
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert tool_config
|
||||||
|
let tool_choice =
|
||||||
|
req.tool_config
|
||||||
|
.and_then(|tc| match tc.function_calling_config.mode.as_str() {
|
||||||
|
"AUTO" => Some(ToolChoice::Type(ToolChoiceType::Auto)),
|
||||||
|
"NONE" => Some(ToolChoice::Type(ToolChoiceType::None)),
|
||||||
|
"ANY" => Some(ToolChoice::Type(ToolChoiceType::Required)),
|
||||||
|
_ => None,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(ChatCompletionsRequest {
|
||||||
|
model: req.model,
|
||||||
|
messages,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
max_completion_tokens: max_tokens,
|
||||||
|
stop,
|
||||||
|
tools,
|
||||||
|
tool_choice,
|
||||||
|
presence_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
metadata: req.metadata,
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Gemini GenerateContent -> Anthropic Messages (via OpenAI)
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl TryFrom<GenerateContentRequest> for MessagesRequest {
|
||||||
|
type Error = TransformError;
|
||||||
|
|
||||||
|
fn try_from(req: GenerateContentRequest) -> Result<Self, Self::Error> {
|
||||||
|
// Chain: Gemini -> OpenAI -> Anthropic
|
||||||
|
let chat_req = ChatCompletionsRequest::try_from(req)?;
|
||||||
|
MessagesRequest::try_from(chat_req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// TESTS
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::apis::gemini::{Content, FunctionCall, Part};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_gemini_to_openai_basic() {
|
||||||
|
let req = GenerateContentRequest {
|
||||||
|
model: "gemini-pro".to_string(),
|
||||||
|
contents: vec![
|
||||||
|
Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some("Hello".to_string()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
Content {
|
||||||
|
role: Some("model".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some("Hi there!".to_string()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
system_instruction: Some(Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some("Be helpful".to_string()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
}),
|
||||||
|
generation_config: Some(crate::apis::gemini::GenerationConfig {
|
||||||
|
temperature: Some(0.5),
|
||||||
|
max_output_tokens: Some(512),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
|
||||||
|
|
||||||
|
// System + user + assistant = 3 messages
|
||||||
|
assert_eq!(openai_req.messages.len(), 3);
|
||||||
|
assert_eq!(openai_req.messages[0].role, Role::System);
|
||||||
|
assert_eq!(openai_req.messages[1].role, Role::User);
|
||||||
|
assert_eq!(openai_req.messages[2].role, Role::Assistant);
|
||||||
|
|
||||||
|
assert_eq!(openai_req.temperature, Some(0.5));
|
||||||
|
assert_eq!(openai_req.max_completion_tokens, Some(512));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_gemini_to_openai_with_function_calls() {
|
||||||
|
let req = GenerateContentRequest {
|
||||||
|
model: "gemini-pro".to_string(),
|
||||||
|
contents: vec![
|
||||||
|
Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some("Weather?".to_string()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
Content {
|
||||||
|
role: Some("model".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: None,
|
||||||
|
inline_data: None,
|
||||||
|
function_call: Some(FunctionCall {
|
||||||
|
name: "get_weather".to_string(),
|
||||||
|
args: json!({"location": "NYC"}),
|
||||||
|
}),
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
|
||||||
|
assert_eq!(openai_req.messages.len(), 2);
|
||||||
|
assert!(openai_req.messages[1].tool_calls.is_some());
|
||||||
|
let tc = openai_req.messages[1].tool_calls.as_ref().unwrap();
|
||||||
|
assert_eq!(tc[0].function.name, "get_weather");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_gemini_to_openai_tool_config() {
|
||||||
|
let req = GenerateContentRequest {
|
||||||
|
model: "gemini-pro".to_string(),
|
||||||
|
contents: vec![Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some("test".to_string()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
tool_config: Some(crate::apis::gemini::ToolConfig {
|
||||||
|
function_calling_config: crate::apis::gemini::FunctionCallingConfig {
|
||||||
|
mode: "ANY".to_string(),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
|
||||||
|
assert!(openai_req.tool_choice.is_some());
|
||||||
|
assert_eq!(
|
||||||
|
openai_req.tool_choice.as_ref().unwrap(),
|
||||||
|
&ToolChoice::Type(ToolChoiceType::Required)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
//! Request transformation modules
|
//! Request transformation modules
|
||||||
|
|
||||||
pub mod from_anthropic;
|
pub mod from_anthropic;
|
||||||
|
pub mod from_gemini;
|
||||||
pub mod from_openai;
|
pub mod from_openai;
|
||||||
|
pub mod to_gemini;
|
||||||
|
|
|
||||||
323
crates/hermesllm/src/transforms/request/to_gemini.rs
Normal file
323
crates/hermesllm/src/transforms/request/to_gemini.rs
Normal file
|
|
@ -0,0 +1,323 @@
|
||||||
|
use crate::apis::gemini::{
|
||||||
|
Content, FunctionCall, FunctionCallingConfig, FunctionDeclaration, FunctionResponse,
|
||||||
|
GenerateContentRequest, GenerationConfig, Part, Tool, ToolConfig,
|
||||||
|
};
|
||||||
|
use crate::apis::openai::{ChatCompletionsRequest, Role, ToolChoice, ToolChoiceType};
|
||||||
|
|
||||||
|
use crate::apis::anthropic::MessagesRequest;
|
||||||
|
use crate::clients::TransformError;
|
||||||
|
use crate::transforms::lib::ExtractText;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// OpenAI ChatCompletions -> Gemini GenerateContent
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl TryFrom<ChatCompletionsRequest> for GenerateContentRequest {
|
||||||
|
type Error = TransformError;
|
||||||
|
|
||||||
|
fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
|
||||||
|
let mut contents: Vec<Content> = Vec::new();
|
||||||
|
let mut system_instruction: Option<Content> = None;
|
||||||
|
|
||||||
|
for msg in &req.messages {
|
||||||
|
match msg.role {
|
||||||
|
Role::System => {
|
||||||
|
let text = msg.content.extract_text();
|
||||||
|
system_instruction = Some(Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some(text),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Role::User => {
|
||||||
|
let text = msg.content.extract_text();
|
||||||
|
contents.push(Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: Some(text),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Role::Assistant => {
|
||||||
|
let mut parts = Vec::new();
|
||||||
|
|
||||||
|
// Check for tool calls
|
||||||
|
if let Some(tool_calls) = &msg.tool_calls {
|
||||||
|
for tc in tool_calls {
|
||||||
|
let args: serde_json::Value =
|
||||||
|
serde_json::from_str(&tc.function.arguments).unwrap_or_default();
|
||||||
|
parts.push(Part {
|
||||||
|
text: None,
|
||||||
|
inline_data: None,
|
||||||
|
function_call: Some(FunctionCall {
|
||||||
|
name: tc.function.name.clone(),
|
||||||
|
args,
|
||||||
|
}),
|
||||||
|
function_response: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also include text content if present
|
||||||
|
let text = msg.content.extract_text();
|
||||||
|
if !text.is_empty() {
|
||||||
|
parts.push(Part {
|
||||||
|
text: Some(text),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !parts.is_empty() {
|
||||||
|
contents.push(Content {
|
||||||
|
role: Some("model".to_string()),
|
||||||
|
parts,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Role::Tool => {
|
||||||
|
let text = msg.content.extract_text();
|
||||||
|
let tool_call_id = msg.tool_call_id.clone().unwrap_or_default();
|
||||||
|
let response_value = serde_json::from_str(&text)
|
||||||
|
.unwrap_or_else(|_| serde_json::json!({"result": text}));
|
||||||
|
|
||||||
|
contents.push(Content {
|
||||||
|
role: Some("user".to_string()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: None,
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: Some(FunctionResponse {
|
||||||
|
name: tool_call_id,
|
||||||
|
response: response_value,
|
||||||
|
}),
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert generation config
|
||||||
|
let generation_config = {
|
||||||
|
let gc = GenerationConfig {
|
||||||
|
temperature: req.temperature,
|
||||||
|
top_p: req.top_p,
|
||||||
|
top_k: None,
|
||||||
|
max_output_tokens: req.max_completion_tokens.or(req.max_tokens),
|
||||||
|
stop_sequences: req.stop,
|
||||||
|
response_mime_type: None,
|
||||||
|
candidate_count: None,
|
||||||
|
presence_penalty: req.presence_penalty,
|
||||||
|
frequency_penalty: req.frequency_penalty,
|
||||||
|
};
|
||||||
|
// Only include if any field is set
|
||||||
|
if gc.temperature.is_some()
|
||||||
|
|| gc.top_p.is_some()
|
||||||
|
|| gc.max_output_tokens.is_some()
|
||||||
|
|| gc.stop_sequences.is_some()
|
||||||
|
|| gc.presence_penalty.is_some()
|
||||||
|
|| gc.frequency_penalty.is_some()
|
||||||
|
{
|
||||||
|
Some(gc)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert tools
|
||||||
|
let tools = req.tools.map(|openai_tools| {
|
||||||
|
let declarations: Vec<FunctionDeclaration> = openai_tools
|
||||||
|
.iter()
|
||||||
|
.map(|t| FunctionDeclaration {
|
||||||
|
name: t.function.name.clone(),
|
||||||
|
description: t.function.description.clone(),
|
||||||
|
parameters: Some(t.function.parameters.clone()),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
vec![Tool {
|
||||||
|
function_declarations: Some(declarations),
|
||||||
|
code_execution: None,
|
||||||
|
}]
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert tool_choice
|
||||||
|
let tool_config = req.tool_choice.and_then(|tc| {
|
||||||
|
let mode = match tc {
|
||||||
|
ToolChoice::Type(t) => match t {
|
||||||
|
ToolChoiceType::Auto => Some("AUTO".to_string()),
|
||||||
|
ToolChoiceType::None => Some("NONE".to_string()),
|
||||||
|
ToolChoiceType::Required => Some("ANY".to_string()),
|
||||||
|
},
|
||||||
|
ToolChoice::Function { .. } => Some("AUTO".to_string()),
|
||||||
|
};
|
||||||
|
mode.map(|m| ToolConfig {
|
||||||
|
function_calling_config: FunctionCallingConfig { mode: m },
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(GenerateContentRequest {
|
||||||
|
model: req.model,
|
||||||
|
contents,
|
||||||
|
generation_config,
|
||||||
|
tools,
|
||||||
|
tool_config,
|
||||||
|
safety_settings: None,
|
||||||
|
system_instruction,
|
||||||
|
cached_content: None,
|
||||||
|
metadata: req.metadata,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Anthropic Messages -> Gemini GenerateContent (via OpenAI)
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl TryFrom<MessagesRequest> for GenerateContentRequest {
|
||||||
|
type Error = TransformError;
|
||||||
|
|
||||||
|
fn try_from(req: MessagesRequest) -> Result<Self, Self::Error> {
|
||||||
|
// Chain: Anthropic -> OpenAI -> Gemini
|
||||||
|
let chat_req = ChatCompletionsRequest::try_from(req)?;
|
||||||
|
GenerateContentRequest::try_from(chat_req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// TESTS
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_openai_to_gemini_basic() {
|
||||||
|
let req: ChatCompletionsRequest = serde_json::from_value(json!({
|
||||||
|
"model": "gemini-pro",
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are helpful"},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
{"role": "user", "content": "How are you?"}
|
||||||
|
],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 1024
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let gemini_req = GenerateContentRequest::try_from(req).unwrap();
|
||||||
|
|
||||||
|
// System should be in system_instruction
|
||||||
|
assert!(gemini_req.system_instruction.is_some());
|
||||||
|
let sys = gemini_req.system_instruction.as_ref().unwrap();
|
||||||
|
assert_eq!(sys.parts[0].text.as_deref(), Some("You are helpful"));
|
||||||
|
|
||||||
|
// 3 content messages (user, model, user)
|
||||||
|
assert_eq!(gemini_req.contents.len(), 3);
|
||||||
|
assert_eq!(gemini_req.contents[0].role.as_deref(), Some("user"));
|
||||||
|
assert_eq!(gemini_req.contents[1].role.as_deref(), Some("model"));
|
||||||
|
assert_eq!(gemini_req.contents[2].role.as_deref(), Some("user"));
|
||||||
|
|
||||||
|
// Generation config
|
||||||
|
assert_eq!(
|
||||||
|
gemini_req.generation_config.as_ref().unwrap().temperature,
|
||||||
|
Some(0.7)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
gemini_req
|
||||||
|
.generation_config
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.max_output_tokens,
|
||||||
|
Some(1024)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_openai_to_gemini_with_tools() {
|
||||||
|
let req: ChatCompletionsRequest = serde_json::from_value(json!({
|
||||||
|
"model": "gemini-pro",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather?"}
|
||||||
|
],
|
||||||
|
"tools": [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather",
|
||||||
|
"parameters": {"type": "object", "properties": {"location": {"type": "string"}}}
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"tool_choice": "auto"
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let gemini_req = GenerateContentRequest::try_from(req).unwrap();
|
||||||
|
assert!(gemini_req.tools.is_some());
|
||||||
|
let tools = gemini_req.tools.as_ref().unwrap();
|
||||||
|
assert_eq!(tools.len(), 1);
|
||||||
|
let decls = tools[0].function_declarations.as_ref().unwrap();
|
||||||
|
assert_eq!(decls[0].name, "get_weather");
|
||||||
|
|
||||||
|
assert!(gemini_req.tool_config.is_some());
|
||||||
|
assert_eq!(
|
||||||
|
gemini_req
|
||||||
|
.tool_config
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.function_calling_config
|
||||||
|
.mode,
|
||||||
|
"AUTO"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_openai_to_gemini_with_tool_calls() {
|
||||||
|
let req: ChatCompletionsRequest = serde_json::from_value(json!({
|
||||||
|
"model": "gemini-pro",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "call_123",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": "{\"location\": \"NYC\"}"
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_123",
|
||||||
|
"content": "Sunny, 72F"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let gemini_req = GenerateContentRequest::try_from(req).unwrap();
|
||||||
|
assert_eq!(gemini_req.contents.len(), 3);
|
||||||
|
|
||||||
|
// Assistant with function_call
|
||||||
|
let model_content = &gemini_req.contents[1];
|
||||||
|
assert_eq!(model_content.role.as_deref(), Some("model"));
|
||||||
|
assert!(model_content.parts[0].function_call.is_some());
|
||||||
|
|
||||||
|
// Tool response
|
||||||
|
let tool_content = &gemini_req.contents[2];
|
||||||
|
assert_eq!(tool_content.role.as_deref(), Some("user"));
|
||||||
|
assert!(tool_content.parts[0].function_response.is_some());
|
||||||
|
}
|
||||||
|
}
|
||||||
417
crates/hermesllm/src/transforms/response/from_gemini.rs
Normal file
417
crates/hermesllm/src/transforms/response/from_gemini.rs
Normal file
|
|
@ -0,0 +1,417 @@
|
||||||
|
use crate::apis::anthropic::MessagesResponse;
|
||||||
|
use crate::apis::gemini::GenerateContentResponse;
|
||||||
|
use crate::apis::openai::{
|
||||||
|
ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason,
|
||||||
|
FunctionCall as OpenAIFunctionCall, MessageDelta, ResponseMessage, Role, StreamChoice,
|
||||||
|
ToolCall as OpenAIToolCall, Usage,
|
||||||
|
};
|
||||||
|
use crate::clients::TransformError;
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Gemini GenerateContentResponse -> OpenAI ChatCompletionsResponse
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
fn map_finish_reason(gemini_reason: Option<&str>) -> Option<FinishReason> {
|
||||||
|
gemini_reason.map(|r| match r {
|
||||||
|
"STOP" => FinishReason::Stop,
|
||||||
|
"MAX_TOKENS" => FinishReason::Length,
|
||||||
|
"SAFETY" | "RECITATION" => FinishReason::ContentFilter,
|
||||||
|
_ => FinishReason::Stop,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<GenerateContentResponse> for ChatCompletionsResponse {
|
||||||
|
type Error = TransformError;
|
||||||
|
|
||||||
|
fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
|
||||||
|
let candidates = resp.candidates.unwrap_or_default();
|
||||||
|
let candidate = candidates.first();
|
||||||
|
|
||||||
|
let mut content_text = String::new();
|
||||||
|
let mut tool_calls: Vec<OpenAIToolCall> = Vec::new();
|
||||||
|
|
||||||
|
if let Some(candidate) = candidate {
|
||||||
|
if let Some(ref content) = candidate.content {
|
||||||
|
for (i, part) in content.parts.iter().enumerate() {
|
||||||
|
if let Some(ref text) = part.text {
|
||||||
|
content_text.push_str(text);
|
||||||
|
}
|
||||||
|
if let Some(ref fc) = part.function_call {
|
||||||
|
tool_calls.push(OpenAIToolCall {
|
||||||
|
id: format!("call_{}", i),
|
||||||
|
call_type: "function".to_string(),
|
||||||
|
function: OpenAIFunctionCall {
|
||||||
|
name: fc.name.clone(),
|
||||||
|
arguments: serde_json::to_string(&fc.args).unwrap_or_default(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let finish_reason = candidate
|
||||||
|
.and_then(|c| map_finish_reason(c.finish_reason.as_deref()))
|
||||||
|
.unwrap_or(FinishReason::Stop);
|
||||||
|
|
||||||
|
let message_content = if content_text.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(content_text)
|
||||||
|
};
|
||||||
|
|
||||||
|
let tool_calls_opt = if tool_calls.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(tool_calls)
|
||||||
|
};
|
||||||
|
|
||||||
|
let choice = Choice {
|
||||||
|
index: 0,
|
||||||
|
message: ResponseMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: message_content,
|
||||||
|
tool_calls: tool_calls_opt,
|
||||||
|
refusal: None,
|
||||||
|
annotations: None,
|
||||||
|
audio: None,
|
||||||
|
function_call: None,
|
||||||
|
},
|
||||||
|
finish_reason: Some(finish_reason),
|
||||||
|
logprobs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let usage = resp
|
||||||
|
.usage_metadata
|
||||||
|
.map(|um| Usage {
|
||||||
|
prompt_tokens: um.prompt_token_count.unwrap_or(0),
|
||||||
|
completion_tokens: um.candidates_token_count.unwrap_or(0),
|
||||||
|
total_tokens: um.total_token_count.unwrap_or(0),
|
||||||
|
prompt_tokens_details: None,
|
||||||
|
completion_tokens_details: None,
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
Ok(ChatCompletionsResponse {
|
||||||
|
id: format!(
|
||||||
|
"chatcmpl-gemini-{}",
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis()
|
||||||
|
),
|
||||||
|
object: Some("chat.completion".to_string()),
|
||||||
|
created: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
model: resp.model_version.unwrap_or_else(|| "gemini".to_string()),
|
||||||
|
choices: vec![choice],
|
||||||
|
usage,
|
||||||
|
system_fingerprint: None,
|
||||||
|
service_tier: None,
|
||||||
|
metadata: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Gemini GenerateContentResponse -> Anthropic MessagesResponse (via OpenAI)
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl TryFrom<GenerateContentResponse> for MessagesResponse {
|
||||||
|
type Error = TransformError;
|
||||||
|
|
||||||
|
fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
|
||||||
|
// Chain: Gemini -> OpenAI -> Anthropic
|
||||||
|
let chat_resp = ChatCompletionsResponse::try_from(resp)?;
|
||||||
|
MessagesResponse::try_from(chat_resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Gemini GenerateContentResponse -> OpenAI ChatCompletionsStreamResponse
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl TryFrom<GenerateContentResponse> for ChatCompletionsStreamResponse {
|
||||||
|
type Error = TransformError;
|
||||||
|
|
||||||
|
fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
|
||||||
|
let candidates = resp.candidates.unwrap_or_default();
|
||||||
|
let candidate = candidates.first();
|
||||||
|
|
||||||
|
let mut delta_content: Option<String> = None;
|
||||||
|
|
||||||
|
if let Some(candidate) = candidate {
|
||||||
|
if let Some(ref content) = candidate.content {
|
||||||
|
let mut text_parts = Vec::new();
|
||||||
|
|
||||||
|
for part in content.parts.iter() {
|
||||||
|
if let Some(ref text) = part.text {
|
||||||
|
text_parts.push(text.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !text_parts.is_empty() {
|
||||||
|
delta_content = Some(text_parts.join(""));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let finish_reason = candidate.and_then(|c| map_finish_reason(c.finish_reason.as_deref()));
|
||||||
|
|
||||||
|
let role = candidate
|
||||||
|
.and_then(|c| c.content.as_ref())
|
||||||
|
.and_then(|c| c.role.as_deref())
|
||||||
|
.map(|r| match r {
|
||||||
|
"model" => Role::Assistant,
|
||||||
|
_ => Role::User,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(ChatCompletionsStreamResponse {
|
||||||
|
id: format!(
|
||||||
|
"chatcmpl-gemini-{}",
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis()
|
||||||
|
),
|
||||||
|
object: Some("chat.completion.chunk".to_string()),
|
||||||
|
created: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
model: resp.model_version.unwrap_or_else(|| "gemini".to_string()),
|
||||||
|
choices: vec![StreamChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: MessageDelta {
|
||||||
|
role,
|
||||||
|
content: delta_content,
|
||||||
|
tool_calls: None,
|
||||||
|
refusal: None,
|
||||||
|
function_call: None,
|
||||||
|
},
|
||||||
|
finish_reason,
|
||||||
|
logprobs: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
system_fingerprint: None,
|
||||||
|
service_tier: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// REVERSE: OpenAI ChatCompletionsResponse -> Gemini GenerateContentResponse
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
impl TryFrom<ChatCompletionsResponse> for GenerateContentResponse {
|
||||||
|
type Error = TransformError;
|
||||||
|
|
||||||
|
fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
|
||||||
|
use crate::apis::gemini::{Candidate, Content, FunctionCall, Part, UsageMetadata};
|
||||||
|
|
||||||
|
let candidates = if let Some(choice) = resp.choices.first() {
|
||||||
|
let mut parts = Vec::new();
|
||||||
|
|
||||||
|
// Text content
|
||||||
|
if let Some(ref content) = choice.message.content {
|
||||||
|
if !content.is_empty() {
|
||||||
|
parts.push(Part {
|
||||||
|
text: Some(content.clone()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool calls
|
||||||
|
if let Some(ref tool_calls) = choice.message.tool_calls {
|
||||||
|
for tc in tool_calls {
|
||||||
|
let args: serde_json::Value =
|
||||||
|
serde_json::from_str(&tc.function.arguments).unwrap_or_default();
|
||||||
|
parts.push(Part {
|
||||||
|
text: None,
|
||||||
|
inline_data: None,
|
||||||
|
function_call: Some(FunctionCall {
|
||||||
|
name: tc.function.name.clone(),
|
||||||
|
args,
|
||||||
|
}),
|
||||||
|
function_response: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if parts.is_empty() {
|
||||||
|
parts.push(Part {
|
||||||
|
text: Some(String::new()),
|
||||||
|
inline_data: None,
|
||||||
|
function_call: None,
|
||||||
|
function_response: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let finish_reason = choice.finish_reason.as_ref().map(|fr| match fr {
|
||||||
|
FinishReason::Stop => "STOP".to_string(),
|
||||||
|
FinishReason::Length => "MAX_TOKENS".to_string(),
|
||||||
|
FinishReason::ContentFilter => "SAFETY".to_string(),
|
||||||
|
FinishReason::ToolCalls => "STOP".to_string(),
|
||||||
|
FinishReason::FunctionCall => "STOP".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
vec![Candidate {
|
||||||
|
content: Some(Content {
|
||||||
|
role: Some("model".to_string()),
|
||||||
|
parts,
|
||||||
|
}),
|
||||||
|
finish_reason,
|
||||||
|
safety_ratings: None,
|
||||||
|
}]
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
let usage_metadata = Some(UsageMetadata {
|
||||||
|
prompt_token_count: Some(resp.usage.prompt_tokens),
|
||||||
|
candidates_token_count: Some(resp.usage.completion_tokens),
|
||||||
|
total_token_count: Some(resp.usage.total_tokens),
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(GenerateContentResponse {
|
||||||
|
candidates: Some(candidates),
|
||||||
|
usage_metadata,
|
||||||
|
prompt_feedback: None,
|
||||||
|
model_version: Some(resp.model),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// TESTS
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_gemini_to_openai_response() {
|
||||||
|
let resp: GenerateContentResponse = serde_json::from_value(json!({
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{"text": "Hello! How can I help?"}]
|
||||||
|
},
|
||||||
|
"finishReason": "STOP"
|
||||||
|
}],
|
||||||
|
"usageMetadata": {
|
||||||
|
"promptTokenCount": 5,
|
||||||
|
"candidatesTokenCount": 7,
|
||||||
|
"totalTokenCount": 12
|
||||||
|
},
|
||||||
|
"modelVersion": "gemini-2.0-flash"
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let openai_resp = ChatCompletionsResponse::try_from(resp).unwrap();
|
||||||
|
assert_eq!(openai_resp.choices.len(), 1);
|
||||||
|
let msg = &openai_resp.choices[0].message;
|
||||||
|
assert_eq!(msg.content.as_deref(), Some("Hello! How can I help?"));
|
||||||
|
assert_eq!(
|
||||||
|
openai_resp.choices[0].finish_reason,
|
||||||
|
Some(FinishReason::Stop)
|
||||||
|
);
|
||||||
|
assert_eq!(openai_resp.usage.prompt_tokens, 5);
|
||||||
|
assert_eq!(openai_resp.usage.completion_tokens, 7);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_gemini_to_openai_stream_response() {
|
||||||
|
let resp: GenerateContentResponse = serde_json::from_value(json!({
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{"text": "Hello"}]
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let stream_resp = ChatCompletionsStreamResponse::try_from(resp).unwrap();
|
||||||
|
assert_eq!(stream_resp.choices.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
stream_resp.choices[0].delta.content,
|
||||||
|
Some("Hello".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(stream_resp.choices[0].delta.role, Some(Role::Assistant));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_gemini_to_openai_with_function_call() {
|
||||||
|
let resp: GenerateContentResponse = serde_json::from_value(json!({
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{
|
||||||
|
"functionCall": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"args": {"location": "NYC"}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"finishReason": "STOP"
|
||||||
|
}]
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let openai_resp = ChatCompletionsResponse::try_from(resp).unwrap();
|
||||||
|
let msg = &openai_resp.choices[0].message;
|
||||||
|
assert!(msg.tool_calls.is_some());
|
||||||
|
let tc = msg.tool_calls.as_ref().unwrap();
|
||||||
|
assert_eq!(tc[0].function.name, "get_weather");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_openai_to_gemini_response() {
|
||||||
|
let resp: ChatCompletionsResponse = serde_json::from_value(json!({
|
||||||
|
"id": "chatcmpl-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "gpt-4",
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": "Hello!"},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let gemini_resp = GenerateContentResponse::try_from(resp).unwrap();
|
||||||
|
let candidates = gemini_resp.candidates.as_ref().unwrap();
|
||||||
|
assert_eq!(candidates.len(), 1);
|
||||||
|
let parts = &candidates[0].content.as_ref().unwrap().parts;
|
||||||
|
assert_eq!(parts[0].text.as_deref(), Some("Hello!"));
|
||||||
|
assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_finish_reason_mapping() {
|
||||||
|
assert_eq!(map_finish_reason(Some("STOP")), Some(FinishReason::Stop));
|
||||||
|
assert_eq!(
|
||||||
|
map_finish_reason(Some("MAX_TOKENS")),
|
||||||
|
Some(FinishReason::Length)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
map_finish_reason(Some("SAFETY")),
|
||||||
|
Some(FinishReason::ContentFilter)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
map_finish_reason(Some("RECITATION")),
|
||||||
|
Some(FinishReason::ContentFilter)
|
||||||
|
);
|
||||||
|
assert_eq!(map_finish_reason(None), None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
//! Response transformation modules
|
//! Response transformation modules
|
||||||
|
pub mod from_gemini;
|
||||||
pub mod output_to_input;
|
pub mod output_to_input;
|
||||||
pub mod to_anthropic;
|
pub mod to_anthropic;
|
||||||
pub mod to_openai;
|
pub mod to_openai;
|
||||||
|
|
|
||||||
|
|
@ -217,7 +217,9 @@ impl StreamContext {
|
||||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
|
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
|
||||||
| SupportedUpstreamAPIs::AmazonBedrockConverse(_)
|
| SupportedUpstreamAPIs::AmazonBedrockConverse(_)
|
||||||
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
|
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
|
||||||
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_)
|
||||||
|
| SupportedUpstreamAPIs::GeminiGenerateContent(_)
|
||||||
|
| SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
|
||||||
)
|
)
|
||||||
| None => {
|
| None => {
|
||||||
// OpenAI and default: use Authorization Bearer token
|
// OpenAI and default: use Authorization Bearer token
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue