From d9e3d0b2bb7ed94ff0532199a514753ddabc4353 Mon Sep 17 00:00:00 2001 From: Valerio Date: Tue, 16 Jun 2026 15:27:56 +0200 Subject: [PATCH] feat(llm): add Gemini provider and fix stale Anthropic default model Adds a Google Gemini provider (Generative Language API) to the chain, ordered Ollama -> OpenAI -> Gemini -> Anthropic so Google credits are preferred with Anthropic as last-resort fallback. System->systemInstruction, assistant->model, json_mode->responseMimeType; model name validated before URL interpolation; maxOutputTokens defaults high for 2.5 thinking models. Also fixes AnthropicProvider default (retired claude-sonnet-4-20250514 -> 404); now claude-sonnet-4-6, honors ANTHROPIC_MODEL. Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/webclaw-llm/src/chain.rs | 14 +- crates/webclaw-llm/src/lib.rs | 2 +- crates/webclaw-llm/src/providers/anthropic.rs | 8 +- crates/webclaw-llm/src/providers/gemini.rs | 363 ++++++++++++++++++ crates/webclaw-llm/src/providers/mod.rs | 1 + 5 files changed, 381 insertions(+), 7 deletions(-) create mode 100644 crates/webclaw-llm/src/providers/gemini.rs diff --git a/crates/webclaw-llm/src/chain.rs b/crates/webclaw-llm/src/chain.rs index 86b0101..e2c6b8b 100644 --- a/crates/webclaw-llm/src/chain.rs +++ b/crates/webclaw-llm/src/chain.rs @@ -1,5 +1,5 @@ /// Provider chain — tries providers in order until one succeeds. -/// Default order: Ollama (local, free) -> OpenAI -> Anthropic. +/// Default order: Ollama (local, free) -> OpenAI -> Gemini -> Anthropic. /// Only includes providers that are actually configured/available. use async_trait::async_trait; use tracing::{debug, warn}; @@ -7,7 +7,8 @@ use tracing::{debug, warn}; use crate::error::LlmError; use crate::provider::{CompletionRequest, LlmProvider}; use crate::providers::{ - anthropic::AnthropicProvider, ollama::OllamaProvider, openai::OpenAiProvider, + anthropic::AnthropicProvider, gemini::GeminiProvider, ollama::OllamaProvider, + openai::OpenAiProvider, }; pub struct ProviderChain { @@ -15,9 +16,11 @@ pub struct ProviderChain { } impl ProviderChain { - /// Build the default chain: Ollama -> OpenAI -> Anthropic. + /// Build the default chain: Ollama -> OpenAI -> Gemini -> Anthropic. /// Ollama is always added (availability checked at call time). /// Cloud providers are only added if their API keys are configured. + /// Gemini sits ahead of Anthropic so Google Cloud credits are preferred, + /// with Anthropic as the last-resort fallback. pub async fn default() -> Self { let mut providers: Vec> = Vec::new(); @@ -34,6 +37,11 @@ impl ProviderChain { providers.push(Box::new(openai)); } + if let Some(gemini) = GeminiProvider::new(None, None, None) { + debug!("gemini configured, adding to chain"); + providers.push(Box::new(gemini)); + } + if let Some(anthropic) = AnthropicProvider::with_base_url(None, None, None) { debug!("anthropic configured, adding to chain"); providers.push(Box::new(anthropic)); diff --git a/crates/webclaw-llm/src/lib.rs b/crates/webclaw-llm/src/lib.rs index 61e2ae7..30fc44c 100644 --- a/crates/webclaw-llm/src/lib.rs +++ b/crates/webclaw-llm/src/lib.rs @@ -1,6 +1,6 @@ /// webclaw-llm: LLM integration with local-first hybrid architecture. /// -/// Provider chain tries Ollama (local) first, falls back to OpenAI, then Anthropic. +/// Provider chain tries Ollama (local) first, falls back to OpenAI, then Gemini, then Anthropic. /// Provides schema-based extraction, prompt extraction, and summarization /// on top of webclaw-core's content pipeline. pub mod chain; diff --git a/crates/webclaw-llm/src/providers/anthropic.rs b/crates/webclaw-llm/src/providers/anthropic.rs index eb15973..c33d7f5 100644 --- a/crates/webclaw-llm/src/providers/anthropic.rs +++ b/crates/webclaw-llm/src/providers/anthropic.rs @@ -48,7 +48,9 @@ impl AnthropicProvider { .unwrap_or_else(|| DEFAULT_ANTHROPIC_BASE_URL.into()) .trim_end_matches('/') .to_string(), - default_model: model.unwrap_or_else(|| "claude-sonnet-4-20250514".into()), + default_model: model + .or_else(|| std::env::var("ANTHROPIC_MODEL").ok()) + .unwrap_or_else(|| "claude-sonnet-4-6".into()), }) } @@ -158,7 +160,7 @@ mod tests { let provider = AnthropicProvider::new(Some("sk-ant-test".into()), None).expect("should construct"); assert_eq!(provider.name(), "anthropic"); - assert_eq!(provider.default_model, "claude-sonnet-4-20250514"); + assert_eq!(provider.default_model, "claude-sonnet-4-6"); assert_eq!(provider.key, "sk-ant-test"); assert_eq!(provider.base_url, "https://api.anthropic.com/v1"); assert_eq!( @@ -178,7 +180,7 @@ mod tests { #[test] fn default_model_accessor() { let provider = AnthropicProvider::new(Some("sk-ant-test".into()), None).unwrap(); - assert_eq!(provider.default_model(), "claude-sonnet-4-20250514"); + assert_eq!(provider.default_model(), "claude-sonnet-4-6"); } #[test] diff --git a/crates/webclaw-llm/src/providers/gemini.rs b/crates/webclaw-llm/src/providers/gemini.rs new file mode 100644 index 0000000..6b77869 --- /dev/null +++ b/crates/webclaw-llm/src/providers/gemini.rs @@ -0,0 +1,363 @@ +/// Google Gemini provider — Gemini models via the Generative Language API. +/// Gemini's request shape differs from OpenAI/Anthropic: the system message is a +/// top-level `systemInstruction`, conversation turns live in `contents` (with the +/// assistant role renamed to `model`), and generation knobs sit under +/// `generationConfig`. API-key auth is sent as an `x-goog-api-key` header. +use std::time::Duration; + +use async_trait::async_trait; +use serde_json::json; + +use crate::clean::strip_thinking_tags; +use crate::error::LlmError; +use crate::provider::{CompletionRequest, LlmProvider}; + +use super::load_api_key; + +const DEFAULT_GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta"; +/// Default model. Gemini 2.5 Flash/Pro are "thinking" models: internal reasoning +/// tokens count against `maxOutputTokens`, so the output budget must comfortably +/// exceed the visible response (see `request_body`) or the model returns +/// `finishReason=MAX_TOKENS` with no text. Set `GEMINI_MODEL` to a non-thinking +/// model (e.g. `gemini-2.0-flash`) to avoid the reasoning overhead entirely. +const DEFAULT_GEMINI_MODEL: &str = "gemini-2.5-flash"; + +/// Gemini puts the model in the URL path, so only plain model identifiers are +/// safe to interpolate. Real model names are ASCII alphanumerics plus `-`/`.`/`_` +/// (e.g. `gemini-2.5-flash`, `gemini-2.0-flash-001`); anything else (`/`, `:`, +/// `?`, `#`, whitespace) could redirect the request to a different path/method. +fn is_safe_model_name(model: &str) -> bool { + !model.is_empty() + && model + .bytes() + .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_')) +} + +pub struct GeminiProvider { + client: reqwest::Client, + key: String, + base_url: String, + default_model: String, +} + +impl GeminiProvider { + /// Returns `None` if no API key is available (param or `GEMINI_API_KEY` env). + pub fn new( + key_override: Option, + base_url: Option, + model: Option, + ) -> Option { + let key = load_api_key(key_override, "GEMINI_API_KEY")?; + + Some(Self { + client: reqwest::Client::builder() + .timeout(Duration::from_secs(120)) + .connect_timeout(Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()), + key, + base_url: base_url + .or_else(|| std::env::var("GEMINI_BASE_URL").ok()) + .unwrap_or_else(|| DEFAULT_GEMINI_BASE_URL.into()) + .trim_end_matches('/') + .to_string(), + default_model: model + .or_else(|| std::env::var("GEMINI_MODEL").ok()) + .unwrap_or_else(|| DEFAULT_GEMINI_MODEL.into()), + }) + } + + pub fn default_model(&self) -> &str { + &self.default_model + } + + /// Build the `generateContent` body from a generic completion request. + /// System messages become `systemInstruction`; user/assistant turns become + /// `contents` (assistant → `model`); `json_mode` constrains the model to + /// valid JSON via `responseMimeType`. + fn request_body(&self, request: &CompletionRequest) -> serde_json::Value { + let contents: Vec = request + .messages + .iter() + .filter(|m| m.role != "system") + .map(|m| { + let role = if m.role == "assistant" { + "model" + } else { + "user" + }; + json!({ "role": role, "parts": [{ "text": m.content }] }) + }) + .collect(); + + let system_parts: Vec = request + .messages + .iter() + .filter(|m| m.role == "system") + .map(|m| json!({ "text": m.content })) + .collect(); + + // `maxOutputTokens` is a ceiling, not a reservation — you're billed per + // token actually produced — so default generously. Gemini 2.5 "thinking" + // models spend part of this budget on internal reasoning; too low a cap + // makes them return `finishReason=MAX_TOKENS` with no visible text. + let mut generation_config = json!({ + "maxOutputTokens": request.max_tokens.unwrap_or(8192), + }); + if let Some(temp) = request.temperature { + generation_config["temperature"] = json!(temp); + } + if request.json_mode { + generation_config["responseMimeType"] = json!("application/json"); + } + + let mut body = json!({ + "contents": contents, + "generationConfig": generation_config, + }); + + // Gemini rejects an empty `systemInstruction`, so only attach it when a + // system message is actually present. + if !system_parts.is_empty() { + body["systemInstruction"] = json!({ "parts": system_parts }); + } + + body + } +} + +#[async_trait] +impl LlmProvider for GeminiProvider { + async fn complete(&self, request: &CompletionRequest) -> Result { + let model = if request.model.is_empty() { + &self.default_model + } else { + &request.model + }; + + // The model goes in the URL path (Gemini's API requires it there, unlike + // OpenAI/Anthropic which take it in the body), so reject anything that + // isn't a plain model identifier to prevent path/query injection from a + // caller-supplied `request.model`. + if !is_safe_model_name(model) { + return Err(LlmError::ProviderError(format!( + "invalid gemini model name: {model:?}" + ))); + } + + let body = self.request_body(request); + + // API-key auth goes in the header, never the URL, so the key can't leak + // into request logs, proxies, or referrer headers. + let url = format!("{}/models/{model}:generateContent", self.base_url); + let resp = self + .client + .post(&url) + .header("x-goog-api-key", &self.key) + .header("content-type", "application/json") + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + let safe_text = text.chars().take(500).collect::(); + return Err(LlmError::ProviderError(format!( + "gemini returned {status}: {safe_text}" + ))); + } + + // Cap response body size to defend against adversarial payloads. + let json = super::response_json_capped(resp).await?; + + // Gemini response: {"candidates":[{"content":{"parts":[{"text":"..."}]}}]}. + // A candidate may carry multiple text parts; concatenate them in order. + let text = json["candidates"][0]["content"]["parts"] + .as_array() + .map(|parts| { + parts + .iter() + .filter_map(|p| p["text"].as_str()) + .collect::() + }) + .unwrap_or_default(); + + if text.is_empty() { + // No usable text. Surface Gemini's finishReason (or a prompt-level + // block reason) so MAX_TOKENS — e.g. a "thinking" model that spent + // its whole maxOutputTokens budget on reasoning — and SAFETY blocks + // are visible in logs/telemetry instead of masquerading as a parse + // failure. The chain falls through to the next provider on any Err. + let reason = json["candidates"][0]["finishReason"] + .as_str() + .or_else(|| json["promptFeedback"]["blockReason"].as_str()) + .unwrap_or("unknown"); + return Err(LlmError::ProviderError(format!( + "gemini returned no text (finishReason={reason})" + ))); + } + + Ok(strip_thinking_tags(&text)) + } + + async fn is_available(&self) -> bool { + !self.key.is_empty() + } + + fn name(&self) -> &str { + "gemini" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::Message; + + fn provider() -> GeminiProvider { + GeminiProvider::new(Some("test-key".into()), None, None).expect("should construct") + } + + fn msg(role: &str, content: &str) -> Message { + Message { + role: role.into(), + content: content.into(), + } + } + + fn request(messages: Vec, json_mode: bool) -> CompletionRequest { + CompletionRequest { + model: String::new(), + messages, + temperature: None, + max_tokens: None, + json_mode, + } + } + + #[test] + fn empty_key_returns_none() { + assert!(GeminiProvider::new(Some(String::new()), None, None).is_none()); + } + + #[test] + fn model_name_validation_blocks_path_injection() { + // Real model identifiers pass. + assert!(is_safe_model_name("gemini-2.5-flash")); + assert!(is_safe_model_name("gemini-2.0-flash-001")); + assert!(is_safe_model_name("gemini-1.5-pro-002")); + // Anything that could alter the request path/method is rejected. + assert!(!is_safe_model_name("")); + assert!(!is_safe_model_name( + "gemini-2.5-flash:streamGenerateContent" + )); + assert!(!is_safe_model_name("../../models/x")); + assert!(!is_safe_model_name("model?alt=sse")); + assert!(!is_safe_model_name("a b")); + } + + #[test] + fn explicit_key_constructs_with_defaults() { + let p = provider(); + assert_eq!(p.name(), "gemini"); + assert_eq!(p.key, "test-key"); + assert_eq!(p.default_model, DEFAULT_GEMINI_MODEL); + assert_eq!(p.default_model(), DEFAULT_GEMINI_MODEL); + assert_eq!(p.base_url, DEFAULT_GEMINI_BASE_URL); + } + + #[test] + fn custom_base_url_trims_trailing_slash_and_model() { + let p = GeminiProvider::new( + Some("test-key".into()), + Some("https://example.test/v1beta/".into()), + Some("gemini-2.5-pro".into()), + ) + .unwrap(); + assert_eq!(p.base_url, "https://example.test/v1beta"); + assert_eq!(p.default_model, "gemini-2.5-pro"); + } + + #[test] + fn maps_user_and_assistant_roles_into_contents() { + let p = provider(); + let body = p.request_body(&request( + vec![msg("user", "hello"), msg("assistant", "hi there")], + false, + )); + let contents = body["contents"].as_array().unwrap(); + assert_eq!(contents.len(), 2); + assert_eq!(contents[0]["role"], "user"); + assert_eq!(contents[0]["parts"][0]["text"], "hello"); + // assistant must be renamed to Gemini's "model" role. + assert_eq!(contents[1]["role"], "model"); + assert_eq!(contents[1]["parts"][0]["text"], "hi there"); + // No system message -> no systemInstruction key at all. + assert!(body.get("systemInstruction").is_none()); + } + + #[test] + fn system_message_becomes_system_instruction_not_contents() { + let p = provider(); + let body = p.request_body(&request( + vec![msg("system", "be terse"), msg("user", "hello")], + false, + )); + let contents = body["contents"].as_array().unwrap(); + assert_eq!(contents.len(), 1, "system message lifted out of contents"); + assert_eq!(contents[0]["role"], "user"); + assert_eq!(body["systemInstruction"]["parts"][0]["text"], "be terse"); + } + + #[test] + fn json_mode_toggles_response_mime_type() { + let p = provider(); + let on = p.request_body(&request(vec![msg("user", "x")], true)); + assert_eq!( + on["generationConfig"]["responseMimeType"], + "application/json" + ); + let off = p.request_body(&request(vec![msg("user", "x")], false)); + assert!(off["generationConfig"].get("responseMimeType").is_none()); + } + + #[test] + fn max_output_tokens_default_and_temperature_override() { + let p = provider(); + let default_body = p.request_body(&request(vec![msg("user", "x")], false)); + assert_eq!(default_body["generationConfig"]["maxOutputTokens"], 8192); + // No temperature set -> key omitted. + assert!( + default_body["generationConfig"] + .get("temperature") + .is_none() + ); + + let mut req = request(vec![msg("user", "x")], false); + req.max_tokens = Some(256); + req.temperature = Some(0.5); // 0.5 is exact in both f32 and f64 + let body = p.request_body(&req); + assert_eq!(body["generationConfig"]["maxOutputTokens"], 256); + assert_eq!(body["generationConfig"]["temperature"], 0.5); + } + + // Env var fallback tests mutate process-global state and race with parallel + // tests. Run in isolation if needed: + // cargo test -p webclaw-llm env_var -- --ignored --test-threads=1 + #[test] + #[ignore = "mutates process env; run with --test-threads=1"] + fn env_var_key_fallback() { + unsafe { std::env::set_var("GEMINI_API_KEY", "gemini-env-key") }; + let p = GeminiProvider::new(None, None, None).expect("should construct from env"); + assert_eq!(p.key, "gemini-env-key"); + unsafe { std::env::remove_var("GEMINI_API_KEY") }; + } + + #[test] + #[ignore = "mutates process env; run with --test-threads=1"] + fn no_key_returns_none() { + unsafe { std::env::remove_var("GEMINI_API_KEY") }; + assert!(GeminiProvider::new(None, None, None).is_none()); + } +} diff --git a/crates/webclaw-llm/src/providers/mod.rs b/crates/webclaw-llm/src/providers/mod.rs index 1e6412b..d6ae34a 100644 --- a/crates/webclaw-llm/src/providers/mod.rs +++ b/crates/webclaw-llm/src/providers/mod.rs @@ -1,4 +1,5 @@ pub mod anthropic; +pub mod gemini; pub mod ollama; pub mod openai;