Infer port from protocol if port is not specified and add ability to override hostname in clusters def (#389)

This commit is contained in:
Adil Hafeez 2025-02-03 14:51:59 -08:00 committed by GitHub
parent 25692bbbfc
commit 962727f244
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 36 additions and 7 deletions

View file

@ -33,6 +33,8 @@ properties:
enum:
- http
- https
http_host:
type: string
additionalProperties: false
required:
- endpoint
@ -66,6 +68,8 @@ properties:
enum:
- http
- https
http_host:
type: string
additionalProperties: false
required:
- name

View file

@ -537,7 +537,11 @@ static_resources:
socket_address:
address: {{ cluster.endpoint }}
port_value: {{ cluster.port }}
{% if cluster.http_host %}
hostname: {{ cluster.http_host }}
{% else %}
hostname: {{ cluster.endpoint }}
{% endif %}
{% if cluster.protocol == "https" %}
transport_socket:
name: envoy.transport_sockets.tls
@ -566,7 +570,11 @@ static_resources:
socket_address:
address: {{ local_llm_provider.endpoint }}
port_value: {{ local_llm_provider.port }}
{% if local_llm_provider.http_host %}
hostname: {{ local_llm_provider.http_host }}
{% else %}
hostname: {{ local_llm_provider.endpoint }}
{% endif %}
{% if local_llm_provider.protocol == "https" %}
transport_socket:
name: envoy.transport_sockets.tls

View file

@ -16,6 +16,20 @@ ARCH_CONFIG_SCHEMA_FILE = os.getenv(
)
def get_endpoint_and_port(endpoint, protocol):
endpoint_tokens = endpoint.split(":")
if len(endpoint_tokens) > 1:
endpoint = endpoint_tokens[0]
port = int(endpoint_tokens[1])
return endpoint, port
else:
if protocol == "http":
port = 80
else:
port = 443
return endpoint, port
def validate_and_render_schema():
env = Environment(loader=FileSystemLoader("./"))
template = env.get_template("envoy.template.yaml")
@ -42,9 +56,11 @@ def validate_and_render_schema():
for name, endpoint_details in endpoints.items():
inferred_clusters[name] = endpoint_details
endpoint = inferred_clusters[name]["endpoint"]
if len(endpoint.split(":")) > 1:
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
protocol = inferred_clusters[name].get("protocol", "http")
(
inferred_clusters[name]["endpoint"],
inferred_clusters[name]["port"],
) = get_endpoint_and_port(endpoint, protocol)
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
@ -77,9 +93,10 @@ def validate_and_render_schema():
if llm_provider.get("endpoint", None):
endpoint = llm_provider["endpoint"]
if len(endpoint.split(":")) > 1:
llm_provider["endpoint"] = endpoint.split(":")[0]
llm_provider["port"] = int(endpoint.split(":")[1])
protocol = llm_provider.get("protocol", "http")
llm_provider["endpoint"], llm_provider["port"] = get_endpoint_and_port(
endpoint, protocol
)
llms_with_endpoint.append(llm_provider)
config_yaml["llm_providers"] = updated_llm_providers

View file

@ -44,7 +44,7 @@ prompt_targets:
endpoints:
frankfurther_api:
endpoint: api.frankfurter.dev:443
endpoint: api.frankfurter.dev
protocol: https
tracing: