diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index 0ca9d42d..9e9abac8 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -76,6 +76,16 @@ properties: type: string http_host: type: string + provider_interface: + type: string + enum: + - arch + - claude + - deepseek + - groq + - mistral + - openai + - gemini routing_preferences: type: array items: diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index c636813b..56d3869a 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -122,8 +122,17 @@ def validate_and_render_schema(): provider = model_name_tokens[0] model_id = "/".join(model_name_tokens[1:]) if provider not in SUPPORTED_PROVIDERS: + if ( + llm_provider.get("base_url", None) is None + or llm_provider.get("provider_interface", None) is None + ): + raise Exception( + f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}" + ) + provider = llm_provider.get("provider_interface", None) + elif llm_provider.get("provider_interface", None) is not None: raise Exception( - f"Unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}" + f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. 
For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' " ) if model_id in model_name_keys: @@ -181,7 +190,7 @@ def validate_and_render_schema(): llm_provider["protocol"] = protocol llms_with_endpoint.append(llm_provider) - if len(llms_with_usage) > 0: + if len(model_usage_name_keys) > 0: routing_llm_provider = config_yaml.get("routing", {}).get("llm_provider", None) if routing_llm_provider and routing_llm_provider not in llm_provider_name_set: raise Exception( diff --git a/demos/use_cases/llm_routing/arch_config.yaml b/demos/use_cases/llm_routing/arch_config.yaml index b49cbb8d..cb3a42e6 100644 --- a/demos/use_cases/llm_routing/arch_config.yaml +++ b/demos/use_cases/llm_routing/arch_config.yaml @@ -34,5 +34,9 @@ llm_providers: - access_key: $GEMINI_API_KEY model: gemini/gemini-1.5-pro-latest + - model: custom/test-model + base_url: http://host.docker.internal:11223 + provider_interface: openai + tracing: random_sampling: 100