diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 8f0dcefd..661b9b51 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -1,9 +1,11 @@ import json import os +from cli.utils import convert_legacy_llm_providers from jinja2 import Environment, FileSystemLoader import yaml from jsonschema import validate from urllib.parse import urlparse +from copy import deepcopy SUPPORTED_PROVIDERS = [ @@ -14,10 +16,6 @@ SUPPORTED_PROVIDERS = [ "openai", "gemini", "anthropic", - "together_ai", - "azure_openai", - "xai", - "ollama", ] @@ -69,17 +67,49 @@ def validate_and_render_schema(): _ = yaml.safe_load(arch_config_schema) inferred_clusters = {} + listeners, llm_gateway, prompt_gateway = convert_legacy_llm_providers( + config_yaml.get("listeners"), config_yaml.get("llm_providers") + ) + + config_yaml["listeners"] = listeners + endpoints = config_yaml.get("endpoints", {}) + # Process agents section and convert to endpoints + agents = config_yaml.get("agents", []) + for agent in agents: + agent_name = agent.get("name") + agent_endpoint = agent.get("endpoint") + + if agent_name and agent_endpoint: + urlparse_result = urlparse(agent_endpoint) + if urlparse_result.scheme and urlparse_result.hostname: + protocol = urlparse_result.scheme + + port = urlparse_result.port + if port is None: + if protocol == "http": + port = 80 + else: + port = 443 + + endpoints[agent_name] = { + "endpoint": urlparse_result.hostname, + "port": port, + "protocol": protocol, + } + # override the inferred clusters with the ones defined in the config for name, endpoint_details in endpoints.items(): inferred_clusters[name] = endpoint_details - endpoint = inferred_clusters[name]["endpoint"] - protocol = inferred_clusters[name].get("protocol", "http") - ( - inferred_clusters[name]["endpoint"], - inferred_clusters[name]["port"], - ) = get_endpoint_and_port(endpoint, protocol) + # Only call get_endpoint_and_port for manually defined endpoints, not agent-derived ones + if "port" not in endpoint_details: + endpoint = inferred_clusters[name]["endpoint"] + protocol = inferred_clusters[name].get("protocol", "http") + ( + inferred_clusters[name]["endpoint"], + inferred_clusters[name]["port"], + ) = get_endpoint_and_port(endpoint, protocol) print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters)) @@ -96,105 +126,162 @@ def validate_and_render_schema(): arch_tracing = config_yaml.get("tracing", {}) llms_with_endpoint = [] - updated_llm_providers = [] + updated_llm_providers = [] llm_provider_name_set = set() + llms_with_usage = [] model_name_keys = set() model_usage_name_keys = set() - for llm_provider in config_yaml["llm_providers"]: - if llm_provider.get("name") in llm_provider_name_set: - raise Exception( - f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider" - ) - model_name = llm_provider.get("model") - if model_name in model_name_keys: - raise Exception( - f"Duplicate model name {model_name}, please provide unique model name for each llm_provider" - ) + # # legacy listeners + # # check if type is array or object + # # if its dict its legacy format let's convert it to array + # prompt_gateway_listener = { + # "name": "ingress_traffic", + # "port": 10000, + # "address": "0.0.0.0", + # "timeout": "30s", + # "protocol": "openai", + # } + # llm_gateway_listener = { + # "name": "egress_traffic", + # "port": 12000, + # "address": "0.0.0.0", + # "timeout": "30s", + # "llm_providers": [], + # "protocol": "openai", + # } + # if isinstance(config_yaml["listeners"], dict): + # ingress_traffic = config_yaml["listeners"].get("ingress_traffic", None) + # egress_traffic = config_yaml["listeners"].get("egress_traffic", {}) + # config_yaml["listeners"] = [] - model_name_keys.add(model_name) - if llm_provider.get("name") is None: - llm_provider["name"] = model_name + # llm_providers = [] + # if config_yaml.get("llm_providers"): + # llm_providers = config_yaml["llm_providers"] + # del config_yaml["llm_providers"] + # llm_gateway_listener["port"] = egress_traffic.get( + # "port", llm_gateway_listener["port"] + # ) + # llm_gateway_listener["address"] = egress_traffic.get( + # "address", llm_gateway_listener["address"] + # ) + # llm_gateway_listener["timeout"] = egress_traffic.get( + # "timeout", llm_gateway_listener["timeout"] + # ) + # llm_gateway_listener["llm_providers"] = llm_providers + # config_yaml["listeners"].append(llm_gateway_listener) - llm_provider_name_set.add(llm_provider.get("name")) + # if ingress_traffic: + # prompt_gateway_listener["port"] = ingress_traffic.get( + # "port", prompt_gateway_listener["port"] + # ) + # prompt_gateway_listener["address"] = ingress_traffic.get( + # "address", prompt_gateway_listener["address"] + # ) + # prompt_gateway_listener["timeout"] = ingress_traffic.get( + # "timeout", prompt_gateway_listener["timeout"] + # ) + # config_yaml["listeners"].append(prompt_gateway_listener) - model_name_tokens = model_name.split("/") - if len(model_name_tokens) < 2: - raise Exception( - f"Invalid model name {model_name}. Please provide model name in the format /." - ) - provider = model_name_tokens[0] - # Validate azure_openai and ollama provider requires base_url - if (provider == "azure_openai" or provider == "ollama") and llm_provider.get( - "base_url" - ) is None: - raise Exception( - f"Provider '{provider}' requires 'base_url' to be set for model {model_name}" - ) + for listener in listeners: + if listener.get("llm_providers") is None or listener.get("llm_providers") == []: + continue + print("Processing listener with llm_providers: ", listener) + name = listener.get("name", None) - model_id = "/".join(model_name_tokens[1:]) - if provider not in SUPPORTED_PROVIDERS: - if ( - llm_provider.get("base_url", None) is None - or llm_provider.get("provider_interface", None) is None - ): + for llm_provider in listener.get("llm_providers", []): + if llm_provider.get("usage", None): + llms_with_usage.append(llm_provider["name"]) + if llm_provider.get("name") in llm_provider_name_set: raise Exception( - f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}" + f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider" ) - provider = llm_provider.get("provider_interface", None) - elif llm_provider.get("provider_interface", None) is not None: - raise Exception( - f"Please provide provider interface as part of model name {model_name} using the format /. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' " - ) - if model_id in model_name_keys: - raise Exception( - f"Duplicate model_id {model_id}, please provide unique model_id for each llm_provider" - ) - model_name_keys.add(model_id) - - for routing_preference in llm_provider.get("routing_preferences", []): - if routing_preference.get("name") in model_usage_name_keys: + model_name = llm_provider.get("model") + if model_name in model_name_keys: raise Exception( - f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference" + f"Duplicate model name {model_name}, please provide unique model name for each llm_provider" ) - model_usage_name_keys.add(routing_preference.get("name")) + model_name_keys.add(model_name) + if llm_provider.get("name") is None: + llm_provider["name"] = model_name - llm_provider["model"] = model_id - llm_provider["provider_interface"] = provider - updated_llm_providers.append(llm_provider) + model_name_tokens = model_name.split("/") + if len(model_name_tokens) < 2: + raise Exception( + f"Invalid model name {model_name}. Please provide model name in the format /." + ) + provider = model_name_tokens[0] + model_id = "/".join(model_name_tokens[1:]) + if provider not in SUPPORTED_PROVIDERS: + if ( + llm_provider.get("base_url", None) is None + or llm_provider.get("provider_interface", None) is None + ): + raise Exception( + f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}" + ) + provider = llm_provider.get("provider_interface", None) + elif llm_provider.get("provider_interface", None) is not None: + raise Exception( + f"Please provide provider interface as part of model name {model_name} using the format /. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' " + ) - if llm_provider.get("base_url", None): - base_url = llm_provider["base_url"] - urlparse_result = urlparse(base_url) - url_path = urlparse_result.path - if url_path and url_path != "/": + if model_id in model_name_keys: raise Exception( - f"Please provide base_url without path, got {base_url}. Use base_url like 'http://example.com' instead of 'http://example.com/path'." + f"Duplicate model_id {model_id}, please provide unique model_id for each llm_provider" ) - if urlparse_result.scheme == "" or urlparse_result.scheme not in [ - "http", - "https", - ]: + model_name_keys.add(model_id) + + for routing_preference in llm_provider.get("routing_preferences", []): + if routing_preference.get("name") in model_usage_name_keys: + raise Exception( + f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference" + ) + model_usage_name_keys.add(routing_preference.get("name")) + + llm_provider["model"] = model_id + llm_provider["provider_interface"] = provider + llm_provider_name_set.add(llm_provider.get("name")) + provider = None + if llm_provider.get("provider") and llm_provider.get("provider_interface"): raise Exception( - "Please provide a valid URL with scheme (http/https) in base_url" + "Please provide either provider or provider_interface, not both" ) - protocol = urlparse_result.scheme - port = urlparse_result.port - if port is None: - if protocol == "http": - port = 80 - else: - port = 443 - endpoint = urlparse_result.hostname - llm_provider["endpoint"] = endpoint - llm_provider["port"] = port - llm_provider["protocol"] = protocol - llm_provider["cluster_name"] = ( - provider + "_" + endpoint - ) # make name unique by appending endpoint - llms_with_endpoint.append(llm_provider) + if llm_provider.get("provider"): + provider = llm_provider["provider"] + llm_provider["provider_interface"] = provider + del llm_provider["provider"] + updated_llm_providers.append(llm_provider) + + if llm_provider.get("base_url", None): + base_url = llm_provider["base_url"] + urlparse_result = urlparse(base_url) + url_path = urlparse_result.path + if url_path and url_path != "/": + raise Exception( + f"Please provide base_url without path, got {base_url}. Use base_url like 'http://example.com' instead of 'http://example.com/path'." + ) + if urlparse_result.scheme == "" or urlparse_result.scheme not in [ + "http", + "https", + ]: + raise Exception( + "Please provide a valid URL with scheme (http/https) in base_url" + ) + protocol = urlparse_result.scheme + port = urlparse_result.port + if port is None: + if protocol == "http": + port = 80 + else: + port = 443 + endpoint = urlparse_result.hostname + llm_provider["endpoint"] = endpoint + llm_provider["port"] = port + llm_provider["protocol"] = protocol + llms_with_endpoint.append(llm_provider) if len(model_usage_name_keys) > 0: routing_llm_provider = config_yaml.get("routing", {}).get("llm_provider", None) @@ -211,6 +298,18 @@ def validate_and_render_schema(): } ) + updated_llm_providers = [] + for listener in listeners: + print("Processing listener: ", listener) + llm_providers = listener.get("llm_providers", None) + if llm_providers is not None and llm_providers != []: + print("processing egress traffic listener") + print("updated_llm_providers: ", updated_llm_providers) + if updated_llm_providers is not None and updated_llm_providers != []: + raise Exception( + "Please provide llm_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with llm_providers" + ) + updated_llm_providers = deepcopy(llm_providers) config_yaml["llm_providers"] = updated_llm_providers # Validate model aliases if present @@ -226,24 +325,6 @@ def validate_and_render_schema(): arch_config_string = yaml.dump(config_yaml) arch_llm_config_string = yaml.dump(config_yaml) - prompt_gateway_listener = config_yaml.get("listeners", {}).get( - "ingress_traffic", {} - ) - if prompt_gateway_listener.get("port") == None: - prompt_gateway_listener["port"] = 10000 # default port for prompt gateway - if prompt_gateway_listener.get("address") == None: - prompt_gateway_listener["address"] = "127.0.0.1" - if prompt_gateway_listener.get("timeout") == None: - prompt_gateway_listener["timeout"] = "10s" - - llm_gateway_listener = config_yaml.get("listeners", {}).get("egress_traffic", {}) - if llm_gateway_listener.get("port") == None: - llm_gateway_listener["port"] = 12000 # default port for llm gateway - if llm_gateway_listener.get("address") == None: - llm_gateway_listener["address"] = "127.0.0.1" - if llm_gateway_listener.get("timeout") == None: - llm_gateway_listener["timeout"] = "10s" - use_agent_orchestrator = config_yaml.get("overrides", {}).get( "use_agent_orchestrator", False ) @@ -266,15 +347,16 @@ def validate_and_render_schema(): print("agent_orchestrator: ", agent_orchestrator) data = { - "prompt_gateway_listener": prompt_gateway_listener, - "llm_gateway_listener": llm_gateway_listener, + "prompt_gateway_listener": prompt_gateway, + "llm_gateway_listener": llm_gateway, "arch_config": arch_config_string, "arch_llm_config": arch_llm_config_string, "arch_clusters": inferred_clusters, - "arch_llm_providers": config_yaml["llm_providers"], + "arch_llm_providers": updated_llm_providers, "arch_tracing": arch_tracing, "local_llms": llms_with_endpoint, "agent_orchestrator": agent_orchestrator, + "listeners": listeners, } rendered = template.render(data)