diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 661b9b51..8971249a 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -16,6 +16,9 @@ SUPPORTED_PROVIDERS = [ "openai", "gemini", "anthropic", + "together_ai", + "azure_openai", + "xai", ] @@ -207,12 +210,21 @@ def validate_and_render_schema(): if llm_provider.get("name") is None: llm_provider["name"] = model_name + llm_provider_name_set.add(llm_provider.get("name")) + model_name_tokens = model_name.split("/") if len(model_name_tokens) < 2: raise Exception( f"Invalid model name {model_name}. Please provide model name in the format /." ) provider = model_name_tokens[0] + + # Validate azure_openai provider requires base_url + if provider == "azure_openai" and llm_provider.get("base_url") is None: + raise Exception( + f"Provider 'azure_openai' requires 'base_url' to be set for model {model_name}" + ) + model_id = "/".join(model_name_tokens[1:]) if provider not in SUPPORTED_PROVIDERS: if ( @@ -281,6 +293,9 @@ def validate_and_render_schema(): llm_provider["endpoint"] = endpoint llm_provider["port"] = port llm_provider["protocol"] = protocol + llm_provider["cluster_name"] = ( + provider + "_" + endpoint + ) # make name unique by appending endpoint llms_with_endpoint.append(llm_provider) if len(model_usage_name_keys) > 0: