diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml
index 1432c0b9..cab5cf17 100644
--- a/arch/arch_config_schema.yaml
+++ b/arch/arch_config_schema.yaml
@@ -62,7 +62,7 @@ properties:
       properties:
         name:
           type: string
-        # this field is deprecated, use provider_interface instead
+        # provider field is deprecated, use provider_interface instead
         provider:
           type: string
           enum:
@@ -78,8 +78,11 @@ properties:
           type: string
         default:
           type: boolean
+        # endpoint field is deprecated, use base_url instead
         endpoint:
           type: string
+        base_url:
+          type: string
         protocol:
           type: string
           enum:
diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py
index 7392849e..f0aae28a 100644
--- a/arch/tools/cli/config_generator.py
+++ b/arch/tools/cli/config_generator.py
@@ -3,6 +3,7 @@ import os
 from jinja2 import Environment, FileSystemLoader
 import yaml
 from jsonschema import validate
+from urllib.parse import urlparse
 ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
     "ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
 )
@@ -91,6 +92,9 @@
             del llm_provider["provider"]
         updated_llm_providers.append(llm_provider)
 
+        if llm_provider.get("endpoint") and llm_provider.get("base_url"):
+            raise Exception("Please provide either endpoint or base_url, not both")
+
         if llm_provider.get("endpoint", None):
             endpoint = llm_provider["endpoint"]
             protocol = llm_provider.get("protocol", "http")
@@ -98,6 +102,30 @@
                 endpoint, protocol
             )
             llms_with_endpoint.append(llm_provider)
+        elif llm_provider.get("base_url", None):
+            base_url = llm_provider["base_url"]
+            urlparse_result = urlparse(base_url)
+            if llm_provider.get("port"):
+                raise Exception("Please provide port in base_url")
+            if urlparse_result.scheme == "" or urlparse_result.scheme not in [
+                "http",
+                "https",
+            ]:
+                raise Exception(
+                    "Please provide a valid URL with scheme (http/https) in base_url"
+                )
+            protocol = urlparse_result.scheme
+            port = urlparse_result.port
+            if port is None:
+                if protocol == "http":
+                    port = 80
+                else:
+                    port = 443
+            endpoint = urlparse_result.hostname
+            llm_provider["endpoint"] = endpoint
+            llm_provider["port"] = port
+            llm_provider["protocol"] = protocol
+            llms_with_endpoint.append(llm_provider)
 
     config_yaml["llm_providers"] = updated_llm_providers
 
diff --git a/demos/use_cases/llm_routing/arch_config.yaml b/demos/use_cases/llm_routing/arch_config.yaml
index b4f87698..4cccd718 100644
--- a/demos/use_cases/llm_routing/arch_config.yaml
+++ b/demos/use_cases/llm_routing/arch_config.yaml
@@ -33,8 +33,7 @@ llm_providers:
     access_key: $DEEPSEEK_API_KEY
     provider_interface: openai
     model: deepseek-reasoner
-    endpoint: api.deepseek.com
-    protocol: https
+    base_url: https://api.deepseek.com/
 
 tracing:
   random_sampling: 100