diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index a1facae8..ead0a351 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -8,7 +8,14 @@ from urllib.parse import urlparse from copy import deepcopy -SUPPORTED_PROVIDERS = [ +SUPPORTED_PROVIDERS_WITH_BASE_URL = [ + "azure_openai", + "ollama", + "qwen", + "amazon_bedrock", +] + +SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [ "arch", "deepseek", "groq", @@ -17,15 +24,15 @@ SUPPORTED_PROVIDERS = [ "gemini", "anthropic", "together_ai", - "azure_openai", "xai", - "ollama", "moonshotai", "zhipu", - "qwen", - "amazon_bedrock", ] +SUPPORTED_PROVIDERS = ( + SUPPORTED_PROVIDERS_WITHOUT_BASE_URL + SUPPORTED_PROVIDERS_WITH_BASE_URL +) + def get_endpoint_and_port(endpoint, protocol): endpoint_tokens = endpoint.split(":") @@ -189,12 +196,9 @@ def validate_and_render_schema(): provider = model_name_tokens[0] # Validate azure_openai and ollama provider requires base_url - if ( - provider == "azure_openai" - or provider == "ollama" - or provider == "qwen" - or provider == "amazon_bedrock" - ) and model_provider.get("base_url") is None: + if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get( + "base_url" + ) is None: raise Exception( f"Provider '{provider}' requires 'base_url' to be set for model {model_name}" ) @@ -245,11 +249,11 @@ def validate_and_render_schema(): if model_provider.get("base_url", None): base_url = model_provider["base_url"] urlparse_result = urlparse(base_url) - url_path = urlparse_result.path - if url_path and url_path != "/": - raise Exception( - f"Please provide base_url without path, got {base_url}. Use base_url like 'http://example.com' instead of 'http://example.com/path'." - ) + base_url_path_prefix = urlparse_result.path + if base_url_path_prefix and base_url_path_prefix != "/": + # we will now support base_url_path_prefix. 
This means that the user can provide base_url like http://example.com/path and we will extract /path as base_url_path_prefix + model_provider["base_url_path_prefix"] = base_url_path_prefix + if urlparse_result.scheme == "" or urlparse_result.scheme not in [ "http", "https", diff --git a/arch/tools/test/test_config_generator.py b/arch/tools/test/test_config_generator.py index 0d8f69b9..7016a34f 100644 --- a/arch/tools/test/test_config_generator.py +++ b/arch/tools/test/test_config_generator.py @@ -243,14 +243,13 @@ listeners: timeout: 30s llm_providers: - - model: custom/gpt-4o """, }, { - "id": "base_url_no_prefix", - "expected_error": "Please provide base_url without path", + "id": "base_url_with_path_prefix", + "expected_error": None, "arch_config": """ version: v0.1.0 @@ -264,7 +263,7 @@ listeners: llm_providers: - model: custom/gpt-4o - base_url: "http://custom.com/test" + base_url: "http://custom.com/api/v2" provider_interface: openai """, @@ -322,8 +321,7 @@ def test_validate_and_render_schema_tests(monkeypatch, arch_config_test_case): monkeypatch.setenv("TEMPLATE_ROOT", "../") arch_config = arch_config_test_case["arch_config"] - expected_error = arch_config_test_case["expected_error"] - test_id = arch_config_test_case["id"] + expected_error = arch_config_test_case.get("expected_error") arch_config_schema = "" with open("../arch_config_schema.yaml", "r") as file: @@ -346,9 +344,14 @@ def test_validate_and_render_schema_tests(monkeypatch, arch_config_test_case): ] with mock.patch("builtins.open", m_open): with mock.patch("config_generator.Environment"): - with pytest.raises(Exception) as excinfo: + if expected_error: + # Test expects an error + with pytest.raises(Exception) as excinfo: + validate_and_render_schema() + assert expected_error in str(excinfo.value) + else: + # Test expects success - no exception should be raised validate_and_render_schema() - assert expected_error in str(excinfo.value) def test_convert_legacy_llm_providers(): diff --git 
a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index dc1b74e9..27f8ebd9 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -267,6 +267,7 @@ pub struct LlmProvider { pub usage: Option, pub routing_preferences: Option>, pub cluster_name: Option, + pub base_url_path_prefix: Option, } pub trait IntoModels { @@ -307,6 +308,7 @@ impl Default for LlmProvider { usage: None, routing_preferences: None, cluster_name: None, + base_url_path_prefix: None, } } } diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index 5177ce97..e0ad47d3 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -77,70 +77,85 @@ impl SupportedAPIs { request_path: &str, model_id: &str, is_streaming: bool, + base_url_path_prefix: Option<&str>, ) -> String { - let default_endpoint = "/v1/chat/completions".to_string(); + // Helper function to build endpoint with optional prefix override + let build_endpoint = |provider_prefix: &str, suffix: &str| -> String { + let prefix = base_url_path_prefix + .map(|p| p.trim_matches('/')) + .filter(|p| !p.is_empty()) + .unwrap_or(provider_prefix.trim_matches('/')); + + let suffix = suffix.trim_start_matches('/'); + if prefix.is_empty() { + format!("/{}", suffix) + } else { + format!("/{}/{}", prefix, suffix) + } + }; + match self { SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id { - ProviderId::Anthropic => "/v1/messages".to_string(), + ProviderId::Anthropic => build_endpoint("/v1", "/messages"), ProviderId::AmazonBedrock => { if request_path.starts_with("/v1/") && !is_streaming { - format!("/model/{}/converse", model_id) + build_endpoint("", &format!("/model/{}/converse", model_id)) } else if request_path.starts_with("/v1/") && is_streaming { - format!("/model/{}/converse-stream", model_id) + build_endpoint("", &format!("/model/{}/converse-stream", model_id)) } else { - 
default_endpoint + build_endpoint("/v1", "/chat/completions") } } - _ => default_endpoint, + _ => build_endpoint("/v1", "/chat/completions"), }, _ => match provider_id { ProviderId::Groq => { if request_path.starts_with("/v1/") { - format!("/openai{}", request_path) + build_endpoint("/openai", request_path) } else { - default_endpoint + build_endpoint("/v1", "/chat/completions") } } ProviderId::Zhipu => { if request_path.starts_with("/v1/") { - "/api/paas/v4/chat/completions".to_string() + build_endpoint("/api/paas/v4", "/chat/completions") } else { - default_endpoint + build_endpoint("/v1", "/chat/completions") } } ProviderId::Qwen => { if request_path.starts_with("/v1/") { - "/compatible-mode/v1/chat/completions".to_string() + build_endpoint("/compatible-mode/v1", "/chat/completions") } else { - default_endpoint + build_endpoint("/v1", "/chat/completions") } } ProviderId::AzureOpenAI => { if request_path.starts_with("/v1/") { - format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id) + build_endpoint("/openai/deployments", &format!("/{}/chat/completions?api-version=2025-01-01-preview", model_id)) } else { - default_endpoint + build_endpoint("/v1", "/chat/completions") } } ProviderId::Gemini => { if request_path.starts_with("/v1/") { - "/v1beta/openai/chat/completions".to_string() + build_endpoint("/v1beta/openai", "/chat/completions") } else { - default_endpoint + build_endpoint("/v1", "/chat/completions") } } ProviderId::AmazonBedrock => { if request_path.starts_with("/v1/") { if !is_streaming { - format!("/model/{}/converse", model_id) + build_endpoint("", &format!("/model/{}/converse", model_id)) } else { - format!("/model/{}/converse-stream", model_id) + build_endpoint("", &format!("/model/{}/converse-stream", model_id)) } } else { - default_endpoint + build_endpoint("/v1", "/chat/completions") } } - _ => default_endpoint, + _ => build_endpoint("/v1", "/chat/completions"), }, } } @@ -245,4 +260,327 @@ mod tests { 
OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len() ); } + + #[test] + fn test_target_endpoint_without_base_url_prefix() { + let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Test default OpenAI provider + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::OpenAI, + "/v1/chat/completions", + "gpt-4", + false, + None + ), + "/v1/chat/completions" + ); + + // Test Groq provider + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Groq, + "/v1/chat/completions", + "llama2", + false, + None + ), + "/openai/v1/chat/completions" + ); + + // Test Zhipu provider + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Zhipu, + "/v1/chat/completions", + "chatglm", + false, + None + ), + "/api/paas/v4/chat/completions" + ); + + // Test Qwen provider + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Qwen, + "/v1/chat/completions", + "qwen-turbo", + false, + None + ), + "/compatible-mode/v1/chat/completions" + ); + + // Test Azure OpenAI provider + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::AzureOpenAI, + "/v1/chat/completions", + "gpt-4", + false, + None + ), + "/openai/deployments/gpt-4/chat/completions?api-version=2025-01-01-preview" + ); + + // Test Gemini provider + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Gemini, + "/v1/chat/completions", + "gemini-pro", + false, + None + ), + "/v1beta/openai/chat/completions" + ); + } + + #[test] + fn test_target_endpoint_with_base_url_prefix() { + let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Test Zhipu with custom base_url_path_prefix + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Zhipu, + "/v1/chat/completions", + "chatglm", + false, + Some("/api/coding/paas/v4") + ), + "/api/coding/paas/v4/chat/completions" + ); + + // Test with prefix without leading slash + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Zhipu, + 
"/v1/chat/completions", + "chatglm", + false, + Some("api/coding/paas/v4") + ), + "/api/coding/paas/v4/chat/completions" + ); + + // Test with prefix with trailing slash + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Zhipu, + "/v1/chat/completions", + "chatglm", + false, + Some("/api/coding/paas/v4/") + ), + "/api/coding/paas/v4/chat/completions" + ); + + // Test OpenAI with custom prefix + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::OpenAI, + "/v1/chat/completions", + "gpt-4", + false, + Some("/custom/api/v2") + ), + "/custom/api/v2/chat/completions" + ); + + // Test Groq with custom prefix + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Groq, + "/v1/chat/completions", + "llama2", + false, + Some("/api/v2") + ), + "/api/v2/v1/chat/completions" + ); + } + + #[test] + fn test_target_endpoint_with_empty_base_url_prefix() { + let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Test with just slashes - trims to empty, uses provider default + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Zhipu, + "/v1/chat/completions", + "chatglm", + false, + Some("/") + ), + "/api/paas/v4/chat/completions" + ); + + // Test with None - uses provider default + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Zhipu, + "/v1/chat/completions", + "chatglm", + false, + None + ), + "/api/paas/v4/chat/completions" + ); + } + + #[test] + fn test_amazon_bedrock_endpoints() { + let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Test Bedrock non-streaming without prefix + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::AmazonBedrock, + "/v1/messages", + "us.amazon.nova-pro-v1:0", + false, + None + ), + "/model/us.amazon.nova-pro-v1:0/converse" + ); + + // Test Bedrock streaming without prefix + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::AmazonBedrock, + "/v1/messages", + "us.amazon.nova-pro-v1:0", + true, + None + ), + 
"/model/us.amazon.nova-pro-v1:0/converse-stream" + ); + + // Test Bedrock non-streaming with prefix (prefix shouldn't affect bedrock) + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::AmazonBedrock, + "/v1/messages", + "us.amazon.nova-pro-v1:0", + false, + Some("/custom/path") + ), + "/custom/path/model/us.amazon.nova-pro-v1:0/converse" + ); + + // Test Bedrock streaming with prefix + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::AmazonBedrock, + "/v1/messages", + "us.amazon.nova-pro-v1:0", + true, + Some("/custom/path") + ), + "/custom/path/model/us.amazon.nova-pro-v1:0/converse-stream" + ); + } + + #[test] + fn test_anthropic_messages_endpoint() { + let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Test Anthropic without prefix + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Anthropic, + "/v1/messages", + "claude-3-opus", + false, + None + ), + "/v1/messages" + ); + + // Test Anthropic with prefix + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Anthropic, + "/v1/messages", + "claude-3-opus", + false, + Some("/api/v2") + ), + "/api/v2/messages" + ); + } + + #[test] + fn test_non_v1_request_paths() { + let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Test Groq with non-v1 path (should use default) + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Groq, + "/custom/path", + "llama2", + false, + None + ), + "/v1/chat/completions" + ); + + // Test Zhipu with non-v1 path + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Zhipu, + "/custom/path", + "chatglm", + false, + None + ), + "/v1/chat/completions" + ); + + // Test with prefix on non-v1 path + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Zhipu, + "/custom/path", + "chatglm", + false, + Some("/api/v2") + ), + "/api/v2/chat/completions" + ); + } + + #[test] + fn test_azure_openai_with_query_params() { + let api = 
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Test Azure without prefix - should include query params + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::AzureOpenAI, + "/v1/chat/completions", + "gpt-4-deployment", + false, + None + ), + "/openai/deployments/gpt-4-deployment/chat/completions?api-version=2025-01-01-preview" + ); + + // Test Azure with prefix - prefix should replace /openai/deployments + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::AzureOpenAI, + "/v1/chat/completions", + "gpt-4-deployment", + false, + Some("/custom/azure/path") + ), + "/custom/azure/path/gpt-4-deployment/chat/completions?api-version=2025-01-01-preview" + ); + } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 870530ab..1098185d 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -121,6 +121,7 @@ impl StreamContext { .as_ref() .unwrap_or(&"".to_string()), self.streaming_response, + self.llm_provider().base_url_path_prefix.as_deref(), ); if target_endpoint != request_path { self.set_http_request_header(":path", Some(&target_endpoint)); diff --git a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml index ae1e2499..b9dcab81 100644 --- a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml +++ b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml @@ -43,7 +43,6 @@ llm_providers: access_key: $AWS_BEARER_TOKEN_BEDROCK base_url: https://bedrock-runtime.us-west-2.amazonaws.com - # Ollama Models - model: ollama/llama3.1 base_url: http://host.docker.internal:11434 diff --git a/docs/source/concepts/llm_providers/supported_providers.rst b/docs/source/concepts/llm_providers/supported_providers.rst index 2a58f328..be063885 100644 --- a/docs/source/concepts/llm_providers/supported_providers.rst +++ 
b/docs/source/concepts/llm_providers/supported_providers.rst @@ -36,7 +36,7 @@ All providers are configured in the ``llm_providers`` section of your ``arch_con - ``access_key``: API key for authentication (supports environment variables) - ``default``: Mark a model as the default (optional, boolean) - ``name``: Custom name for the provider instance (optional) -- ``base_url``: Custom endpoint URL (required for some providers) +- ``base_url``: Custom endpoint URL (required for some providers, optional for others - see :ref:`base_url_details`) Provider Categories ------------------- @@ -493,6 +493,8 @@ Zhipu AI Providers Requiring Base URL ---------------------------- +The following providers require a ``base_url`` parameter to be configured. For detailed information on base URL configuration including path prefix behavior and examples, see :ref:`base_url_details`. + Azure OpenAI ~~~~~~~~~~~~ @@ -616,6 +618,70 @@ For providers that implement the OpenAI API but aren't natively supported: base_url: http://localhost:8000 provider_interface: openai +.. _base_url_details: + +Base URL Configuration +---------------------- + +The ``base_url`` parameter allows you to specify custom endpoints for model providers. It supports both hostname and path components, enabling flexible routing to different API endpoints. + +**Format:** ``://[:][/]`` + +**Components:** + +- ``scheme``: ``http`` or ``https`` +- ``hostname``: API server hostname or IP address +- ``port``: Optional, defaults to 80 for http, 443 for https +- ``path``: Optional path prefix that **replaces** the provider's default API path + +**How Path Prefixes Work:** + +When you include a path in ``base_url``, it replaces the provider's default path prefix while preserving the endpoint suffix: + +- **Without path prefix**: Uses the provider's default path structure +- **With path prefix**: Your custom path replaces the provider's default prefix, then the endpoint suffix is appended + +**Configuration Examples:** + +.. 
code-block:: yaml + + llm_providers: + # Simple hostname only - uses provider's default path + - model: zhipu/glm-4.6 + access_key: $ZHIPU_API_KEY + base_url: https://api.z.ai + # Results in: https://api.z.ai/api/paas/v4/chat/completions + + # With custom path prefix - replaces provider's default path + - model: zhipu/glm-4.6 + access_key: $ZHIPU_API_KEY + base_url: https://api.z.ai/api/coding/paas/v4 + # Results in: https://api.z.ai/api/coding/paas/v4/chat/completions + + # Azure with custom path + - model: azure_openai/gpt-4 + access_key: $AZURE_API_KEY + base_url: https://mycompany.openai.azure.com/custom/deployment/path + # Results in: https://mycompany.openai.azure.com/custom/deployment/path/gpt-4/chat/completions?api-version=2025-01-01-preview + + # Behind a proxy or API gateway + - model: openai/gpt-4o + access_key: $OPENAI_API_KEY + base_url: https://proxy.company.com/ai-gateway/openai + # Results in: https://proxy.company.com/ai-gateway/openai/chat/completions + + # Local endpoint with custom port + - model: ollama/llama3.1 + base_url: http://localhost:8080 + # Results in: http://localhost:8080/v1/chat/completions + + # Custom provider with path prefix + - model: vllm/custom-model + access_key: $VLLM_API_KEY + base_url: https://vllm.example.com/models/v2 + provider_interface: openai + # Results in: https://vllm.example.com/models/v2/chat/completions + Advanced Configuration ----------------------