fix: support LLM provider compatibility options

Closes #36
This commit is contained in:
Valerio 2026-05-06 11:36:53 +02:00
parent 86183b11e4
commit a3aa4bce6f
6 changed files with 193 additions and 16 deletions

View file

@ -34,7 +34,7 @@ impl ProviderChain {
providers.push(Box::new(openai));
}
if let Some(anthropic) = AnthropicProvider::new(None, None) {
if let Some(anthropic) = AnthropicProvider::with_base_url(None, None, None) {
debug!("anthropic configured, adding to chain");
providers.push(Box::new(anthropic));
}

View file

@ -10,23 +10,38 @@ use crate::provider::{CompletionRequest, LlmProvider};
use super::load_api_key;
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
const DEFAULT_ANTHROPIC_BASE_URL: &str = "https://api.anthropic.com/v1";
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// LLM provider that talks to an Anthropic-style Messages HTTP endpoint.
pub struct AnthropicProvider {
/// HTTP client used to send requests.
client: reqwest::Client,
/// API key, sent in the `x-api-key` request header.
key: String,
/// Endpoint base URL, stored with trailing slashes removed.
base_url: String,
/// Model name handed out by `default_model()`.
default_model: String,
}
impl AnthropicProvider {
/// Builds a provider without an explicit base URL.
///
/// Delegates to [`Self::with_base_url`], passing no base URL override.
/// Returns `None` if no API key is available (param or env).
pub fn new(key_override: Option<String>, model: Option<String>) -> Option<Self> {
    let base_url_override: Option<String> = None;
    Self::with_base_url(key_override, base_url_override, model)
}
/// Builds a provider with an optional custom endpoint base URL.
///
/// The base URL is taken from the first usable candidate among:
/// 1. the `base_url` argument,
/// 2. the `ANTHROPIC_BASE_URL` environment variable,
/// 3. `DEFAULT_ANTHROPIC_BASE_URL`.
///
/// Each candidate has surrounding whitespace and trailing slashes removed.
/// A candidate that is empty afterwards (e.g. an exported-but-blank
/// environment variable) is treated as not set, so the provider never ends
/// up with an empty base URL.
///
/// `model` replaces the built-in default model name when provided.
///
/// Returns `None` if no API key is available (param or env).
pub fn with_base_url(
    key_override: Option<String>,
    base_url: Option<String>,
    model: Option<String>,
) -> Option<Self> {
    // Strips whitespace and trailing slashes; yields `None` when nothing usable remains.
    fn normalize(candidate: String) -> Option<String> {
        let trimmed = candidate.trim().trim_end_matches('/');
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed.to_string())
        }
    }

    let key = load_api_key(key_override, "ANTHROPIC_API_KEY")?;

    let base_url = base_url
        .and_then(normalize)
        .or_else(|| std::env::var("ANTHROPIC_BASE_URL").ok().and_then(normalize))
        .unwrap_or_else(|| DEFAULT_ANTHROPIC_BASE_URL.to_string());

    Some(Self {
        client: reqwest::Client::new(),
        key,
        base_url,
        default_model: model.unwrap_or_else(|| "claude-sonnet-4-20250514".into()),
    })
}
@ -34,6 +49,14 @@ impl AnthropicProvider {
/// Name of the model this provider was configured with.
pub fn default_model(&self) -> &str {
    self.default_model.as_str()
}
/// Full URL of the Messages endpoint.
///
/// The stored base URL is returned unchanged when it already ends with
/// `/messages`; otherwise `/messages` is appended to it.
fn messages_url(&self) -> String {
    const MESSAGES_SUFFIX: &str = "/messages";

    let mut url = self.base_url.clone();
    if !url.ends_with(MESSAGES_SUFFIX) {
        url.push_str(MESSAGES_SUFFIX);
    }
    url
}
}
#[async_trait]
@ -74,7 +97,7 @@ impl LlmProvider for AnthropicProvider {
let resp = self
.client
.post(ANTHROPIC_API_URL)
.post(self.messages_url())
.header("x-api-key", &self.key)
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json")
@ -135,6 +158,11 @@ mod tests {
assert_eq!(provider.name(), "anthropic");
assert_eq!(provider.default_model, "claude-sonnet-4-20250514");
assert_eq!(provider.key, "sk-ant-test");
assert_eq!(provider.base_url, "https://api.anthropic.com/v1");
assert_eq!(
provider.messages_url(),
"https://api.anthropic.com/v1/messages"
);
}
#[test]
@ -151,6 +179,35 @@ mod tests {
assert_eq!(provider.default_model(), "claude-sonnet-4-20250514");
}
/// A custom base URL loses its trailing slash and gets `/messages` appended.
#[test]
fn custom_base_url_appends_messages_path() {
    let api_key = Some(String::from("sk-ant-test"));
    let proxy_base = Some(String::from("https://proxy.example.test/anthropic/v1/"));

    let provider = AnthropicProvider::with_base_url(api_key, proxy_base, None).unwrap();

    assert_eq!(provider.base_url, "https://proxy.example.test/anthropic/v1");

    let endpoint = provider.messages_url();
    assert_eq!(endpoint, "https://proxy.example.test/anthropic/v1/messages");
}
/// A base URL that already points at `/messages` is used as-is.
#[test]
fn custom_full_messages_url_is_not_doubled() {
    let api_key = Some(String::from("sk-ant-test"));
    let full_endpoint = Some(String::from("https://proxy.example.test/v1/messages"));

    let provider = AnthropicProvider::with_base_url(api_key, full_endpoint, None).unwrap();

    let endpoint = provider.messages_url();
    assert_eq!(endpoint, "https://proxy.example.test/v1/messages");
}
// Env var fallback tests mutate process-global state and race with parallel tests.
// The code path is trivial (load_api_key -> env::var().ok()). Run in isolation if needed:
// cargo test -p webclaw-llm env_var -- --ignored --test-threads=1

View file

@ -13,6 +13,50 @@ pub struct OpenAiProvider {
key: String,
base_url: String,
default_model: String,
response_format: OpenAiResponseFormat,
}
/// Which `response_format` payload is attached to a request that asks for JSON mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OpenAiResponseFormat {
/// Emits `{ "type": "json_object" }`; also the fallback chosen by `from_env`.
JsonObject,
/// Emits a `json_schema` format wrapping a non-strict, open object schema.
JsonSchema,
/// Emits `{ "type": "text" }`.
Text,
}
impl OpenAiResponseFormat {
    /// Reads the format from the `OPENAI_RESPONSE_FORMAT_TYPE` environment variable.
    ///
    /// Falls back to [`Self::JsonObject`] when the variable cannot be read or
    /// holds a value that [`Self::parse`] does not recognise.
    fn from_env() -> Self {
        match std::env::var("OPENAI_RESPONSE_FORMAT_TYPE") {
            Ok(raw) => Self::parse(&raw).unwrap_or(Self::JsonObject),
            Err(_) => Self::JsonObject,
        }
    }

    /// Maps a format name to a variant, ignoring surrounding whitespace and ASCII case.
    ///
    /// A blank value maps to [`Self::JsonObject`]; unknown names yield `None`.
    fn parse(value: &str) -> Option<Self> {
        let name = value.trim();
        if name.is_empty() || name.eq_ignore_ascii_case("json_object") {
            Some(Self::JsonObject)
        } else if name.eq_ignore_ascii_case("json_schema") {
            Some(Self::JsonSchema)
        } else if name.eq_ignore_ascii_case("text") {
            Some(Self::Text)
        } else {
            None
        }
    }

    /// Builds the JSON value placed under the request's `response_format` key.
    fn as_response_format(self) -> serde_json::Value {
        match self {
            Self::JsonObject => json!({ "type": "json_object" }),
            Self::Text => json!({ "type": "text" }),
            Self::JsonSchema => {
                // Open object schema: any properties are accepted.
                let open_object_schema = json!({
                    "type": "object",
                    "additionalProperties": true
                });
                json!({
                    "type": "json_schema",
                    "json_schema": {
                        "name": "webclaw_response",
                        "schema": open_object_schema,
                        "strict": false
                    }
                })
            }
        }
    }
}
impl OpenAiProvider {
@ -31,23 +75,15 @@ impl OpenAiProvider {
.or_else(|| std::env::var("OPENAI_BASE_URL").ok())
.unwrap_or_else(|| "https://api.openai.com/v1".into()),
default_model: model.unwrap_or_else(|| "gpt-4o-mini".into()),
response_format: OpenAiResponseFormat::from_env(),
})
}
/// Name of the model this provider was configured with.
pub fn default_model(&self) -> &str {
    self.default_model.as_str()
}
}
#[async_trait]
impl LlmProvider for OpenAiProvider {
async fn complete(&self, request: &CompletionRequest) -> Result<String, LlmError> {
let model = if request.model.is_empty() {
&self.default_model
} else {
&request.model
};
fn request_body(&self, request: &CompletionRequest, model: &str) -> serde_json::Value {
let messages: Vec<serde_json::Value> = request
.messages
.iter()
@ -60,7 +96,7 @@ impl LlmProvider for OpenAiProvider {
});
if request.json_mode {
body["response_format"] = json!({ "type": "json_object" });
body["response_format"] = self.response_format.as_response_format();
}
if let Some(temp) = request.temperature {
body["temperature"] = json!(temp);
@ -69,6 +105,21 @@ impl LlmProvider for OpenAiProvider {
body["max_tokens"] = json!(max);
}
body
}
}
#[async_trait]
impl LlmProvider for OpenAiProvider {
async fn complete(&self, request: &CompletionRequest) -> Result<String, LlmError> {
let model = if request.model.is_empty() {
&self.default_model
} else {
&request.model
};
let body = self.request_body(request, model);
let url = format!("{}/chat/completions", self.base_url);
let resp = self
.client
@ -136,6 +187,7 @@ mod tests {
assert_eq!(provider.default_model, "gpt-4o-mini");
assert_eq!(provider.base_url, "https://api.openai.com/v1");
assert_eq!(provider.key, "test-key-123");
assert_eq!(provider.response_format, OpenAiResponseFormat::JsonObject);
}
#[test]
@ -161,6 +213,69 @@ mod tests {
assert_eq!(provider.default_model(), "gpt-4o-mini");
}
/// With the stock format, JSON mode emits OpenAI's `json_object` response format.
#[test]
fn json_mode_defaults_to_openai_json_object() {
    // A blank setting resolves to the stock format.
    assert_eq!(
        OpenAiResponseFormat::parse(""),
        Some(OpenAiResponseFormat::JsonObject)
    );

    let mut provider = OpenAiProvider::new(
        Some("test-key".into()),
        Some("https://api.openai.com/v1".into()),
        None,
    )
    .unwrap();
    // `new` picks the format up from OPENAI_RESPONSE_FORMAT_TYPE in the process
    // environment. Pin it so a value exported on the developer's machine (e.g.
    // `text` for a local backend) cannot make this test fail.
    provider.response_format = OpenAiResponseFormat::JsonObject;

    let req = CompletionRequest {
        model: String::new(),
        messages: vec![],
        temperature: None,
        max_tokens: None,
        json_mode: true,
    };
    let body = provider.request_body(&req, provider.default_model());
    assert_eq!(body["response_format"], json!({ "type": "json_object" }));
}
/// The `JsonSchema` format produces a `json_schema` payload with an object schema.
#[test]
fn json_schema_response_format_for_compatible_backends() {
    let request = CompletionRequest {
        json_mode: true,
        model: String::new(),
        messages: Vec::new(),
        temperature: None,
        max_tokens: None,
    };

    let mut provider = OpenAiProvider::new(
        Some("test-key".into()),
        Some("http://localhost:1234/v1".into()),
        Some("local-model".into()),
    )
    .unwrap();
    provider.response_format = OpenAiResponseFormat::JsonSchema;

    let body = provider.request_body(&request, provider.default_model());
    let format = &body["response_format"];
    assert_eq!(format["type"], "json_schema");
    assert_eq!(format["json_schema"]["schema"]["type"], "object");
}
/// The `Text` format produces a plain `{ "type": "text" }` payload in JSON mode.
#[test]
fn text_response_format_for_lm_studio() {
    let request = CompletionRequest {
        json_mode: true,
        model: String::new(),
        messages: Vec::new(),
        temperature: None,
        max_tokens: None,
    };

    let mut provider = OpenAiProvider::new(
        Some("test-key".into()),
        Some("http://localhost:1234/v1".into()),
        Some("local-model".into()),
    )
    .unwrap();
    provider.response_format = OpenAiResponseFormat::Text;

    let body = provider.request_body(&request, provider.default_model());
    let expected = json!({ "type": "text" });
    assert_eq!(body["response_format"], expected);
}
// Env var fallback tests mutate process-global state and race with parallel tests.
// The code path is trivial (load_api_key -> env::var().ok()). Run in isolation if needed:
// cargo test -p webclaw-llm env_var -- --ignored --test-threads=1