From dbc4b2d68b58859b88d59d4d66112fd70a105f91 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 3 Feb 2025 14:28:27 -0800 Subject: [PATCH] infer port from protocol if port is not specified --- arch/tools/cli/config_generator.py | 29 +++++++++++++++++++----- demos/currency_exchange/arch_config.yaml | 2 +- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index e535894b..447585fb 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -16,6 +16,20 @@ ARCH_CONFIG_SCHEMA_FILE = os.getenv( ) +def get_endpoint_and_port(endpoint, protocol): + endpoint_tokens = endpoint.split(":") + if len(endpoint_tokens) > 1: + endpoint = endpoint_tokens[0] + port = int(endpoint_tokens[1]) + return endpoint, port + else: + if protocol == "http": + port = 80 + else: + port = 443 + return endpoint, port + + def validate_and_render_schema(): env = Environment(loader=FileSystemLoader("./")) template = env.get_template("envoy.template.yaml") @@ -42,9 +56,11 @@ def validate_and_render_schema(): for name, endpoint_details in endpoints.items(): inferred_clusters[name] = endpoint_details endpoint = inferred_clusters[name]["endpoint"] - if len(endpoint.split(":")) > 1: - inferred_clusters[name]["endpoint"] = endpoint.split(":")[0] - inferred_clusters[name]["port"] = int(endpoint.split(":")[1]) + protocol = inferred_clusters[name].get("protocol", "http") + ( + inferred_clusters[name]["endpoint"], + inferred_clusters[name]["port"], + ) = get_endpoint_and_port(endpoint, protocol) print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters)) @@ -77,9 +93,10 @@ def validate_and_render_schema(): if llm_provider.get("endpoint", None): endpoint = llm_provider["endpoint"] - if len(endpoint.split(":")) > 1: - llm_provider["endpoint"] = endpoint.split(":")[0] - llm_provider["port"] = int(endpoint.split(":")[1]) + protocol = llm_provider.get("protocol", "http") + llm_provider["endpoint"], llm_provider["port"] = get_endpoint_and_port( + endpoint, protocol + ) llms_with_endpoint.append(llm_provider) config_yaml["llm_providers"] = updated_llm_providers diff --git a/demos/currency_exchange/arch_config.yaml b/demos/currency_exchange/arch_config.yaml index f8776c48..f51b5904 100644 --- a/demos/currency_exchange/arch_config.yaml +++ b/demos/currency_exchange/arch_config.yaml @@ -44,7 +44,7 @@ prompt_targets: endpoints: frankfurther_api: - endpoint: api.frankfurter.dev:443 + endpoint: api.frankfurter.dev protocol: https tracing: