diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 90160fb3..9128ba36 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -4,6 +4,7 @@ from jinja2 import Environment, FileSystemLoader import yaml from jsonschema import validate from urllib.parse import urlparse +from copy import deepcopy SUPPORTED_PROVIDERS = [ @@ -101,40 +102,62 @@ def validate_and_render_schema(): # check if type is array or object # if its dict its legacy format let's convert it to array - prompt_gateway_listener = None - llm_gateway_listener = None + prompt_gateway_listener = { + "name": "ingress_traffic", + "port": 10000, + "address": "0.0.0.0", + "timeout": "30s", + "protocol": "openai", + } + llm_gateway_listener = { + "name": "egress_traffic", + "port": 12000, + "address": "0.0.0.0", + "timeout": "30s", + "llm_providers": [], + "protocol": "openai", + } if isinstance(config_yaml["listeners"], dict): - egress_traffic = config_yaml["listeners"].get("egress_traffic", None) ingress_traffic = config_yaml["listeners"].get("ingress_traffic", None) + egress_traffic = config_yaml["listeners"].get("egress_traffic", {}) config_yaml["listeners"] = [] + + llm_providers = [] + if config_yaml.get("llm_providers"): + llm_providers = config_yaml["llm_providers"] + del config_yaml["llm_providers"] + llm_gateway_listener["port"] = egress_traffic.get( + "port", llm_gateway_listener["port"] + ) + llm_gateway_listener["address"] = egress_traffic.get( + "address", llm_gateway_listener["address"] + ) + llm_gateway_listener["timeout"] = egress_traffic.get( + "timeout", llm_gateway_listener["timeout"] + ) + llm_gateway_listener["llm_providers"] = llm_providers + config_yaml["listeners"].append(llm_gateway_listener) + if ingress_traffic: - prompt_gateway_listener = { - "name": "ingress_traffic", - "port": ingress_traffic.get("port", 10000), - "address": ingress_traffic.get("address", "0.0.0.0"), - "timeout": ingress_traffic.get("timeout", "30s"), - "protocol": "openai", - } + prompt_gateway_listener["port"] = ingress_traffic.get( + "port", prompt_gateway_listener["port"] + ) + prompt_gateway_listener["address"] = ingress_traffic.get( + "address", prompt_gateway_listener["address"] + ) + prompt_gateway_listener["timeout"] = ingress_traffic.get( + "timeout", prompt_gateway_listener["timeout"] + ) config_yaml["listeners"].append(prompt_gateway_listener) - if egress_traffic: - llm_providers = [] - if config_yaml.get("llm_providers"): - llm_providers = config_yaml["llm_providers"] - del config_yaml["llm_providers"] - llm_gateway_listener = { - "name": "egress_traffic", - "port": egress_traffic.get("port", 12000), - "address": egress_traffic.get("address", "0.0.0.0"), - "timeout": egress_traffic.get("timeout", "30s"), - "llm_providers": llm_providers, - "protocol": "openai", - } - config_yaml["listeners"].append(llm_gateway_listener) for listener in config_yaml["listeners"]: print("Processing listener: ", listener) name = listener.get("name", None) + # TODO: for now we only support llm_providers under egress_traffic listener + if name != "egress_traffic": + continue + for llm_provider in listener.get("llm_providers", []): if llm_provider.get("usage", None): llms_with_usage.append(llm_provider["name"]) @@ -244,8 +267,12 @@ def validate_and_render_schema(): ) for listener in config_yaml["listeners"]: + print("Processing listener: ", listener) if listener.get("name") == "egress_traffic": - listener["llm_providers"] = updated_llm_providers + print("processing egress traffic listener") + print("updated_llm_providers: ", updated_llm_providers) + listener["llm_providers"] = deepcopy(updated_llm_providers) + config_yaml["llm_providers"] = updated_llm_providers arch_config_string = yaml.dump(config_yaml) arch_llm_config_string = yaml.dump(config_yaml)