diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index b532117c..96e14856 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -33,6 +33,8 @@ properties: enum: - http - https + http_host: + type: string additionalProperties: false required: - endpoint @@ -66,6 +68,8 @@ properties: enum: - http - https + http_host: + type: string additionalProperties: false required: - name diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index 588c6f66..5dfa60b3 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -537,7 +537,11 @@ static_resources: socket_address: address: {{ cluster.endpoint }} port_value: {{ cluster.port }} + {% if cluster.http_host %} + hostname: {{ cluster.http_host }} + {% else %} hostname: {{ cluster.endpoint }} + {% endif %} {% if cluster.protocol == "https" %} transport_socket: name: envoy.transport_sockets.tls @@ -566,7 +570,11 @@ static_resources: socket_address: address: {{ local_llm_provider.endpoint }} port_value: {{ local_llm_provider.port }} + {% if local_llm_provider.http_host %} + hostname: {{ local_llm_provider.http_host }} + {% else %} hostname: {{ local_llm_provider.endpoint }} + {% endif %} {% if local_llm_provider.protocol == "https" %} transport_socket: name: envoy.transport_sockets.tls diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index e535894b..447585fb 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -16,6 +16,20 @@ ARCH_CONFIG_SCHEMA_FILE = os.getenv( ) +def get_endpoint_and_port(endpoint, protocol): + endpoint_tokens = endpoint.split(":") + if len(endpoint_tokens) > 1: + endpoint = endpoint_tokens[0] + port = int(endpoint_tokens[1]) + return endpoint, port + else: + if protocol == "http": + port = 80 + else: + port = 443 + return endpoint, port + + def validate_and_render_schema(): env = Environment(loader=FileSystemLoader("./")) template = env.get_template("envoy.template.yaml") @@ -42,9 +56,11 @@ def validate_and_render_schema(): for name, endpoint_details in endpoints.items(): inferred_clusters[name] = endpoint_details endpoint = inferred_clusters[name]["endpoint"] - if len(endpoint.split(":")) > 1: - inferred_clusters[name]["endpoint"] = endpoint.split(":")[0] - inferred_clusters[name]["port"] = int(endpoint.split(":")[1]) + protocol = inferred_clusters[name].get("protocol", "http") + ( + inferred_clusters[name]["endpoint"], + inferred_clusters[name]["port"], + ) = get_endpoint_and_port(endpoint, protocol) print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters)) @@ -77,9 +93,10 @@ def validate_and_render_schema(): if llm_provider.get("endpoint", None): endpoint = llm_provider["endpoint"] - if len(endpoint.split(":")) > 1: - llm_provider["endpoint"] = endpoint.split(":")[0] - llm_provider["port"] = int(endpoint.split(":")[1]) + protocol = llm_provider.get("protocol", "http") + llm_provider["endpoint"], llm_provider["port"] = get_endpoint_and_port( + endpoint, protocol + ) llms_with_endpoint.append(llm_provider) config_yaml["llm_providers"] = updated_llm_providers diff --git a/demos/currency_exchange/arch_config.yaml b/demos/currency_exchange/arch_config.yaml index f8776c48..f51b5904 100644 --- a/demos/currency_exchange/arch_config.yaml +++ b/demos/currency_exchange/arch_config.yaml @@ -44,7 +44,7 @@ prompt_targets: endpoints: frankfurther_api: - endpoint: api.frankfurter.dev:443 + endpoint: api.frankfurter.dev protocol: https tracing: