draft commit to add support for xAI, LambdaAI, TogetherAI, AzureOpenAI

This commit is contained in:
Salman Paracha 2025-09-17 22:47:33 -07:00
parent b56311f458
commit 79ff4bb164
7 changed files with 170 additions and 24 deletions

View file

@ -127,7 +127,7 @@ static_resources:
{% for provider in arch_llm_providers %}
# if endpoint is set then use custom cluster for upstream llm
{% if provider.endpoint %}
{% set llm_cluster_name = provider.name %}
{% set llm_cluster_name = provider.cluster_name %}
{% else %}
{% set llm_cluster_name = provider.provider_interface %}
{% endif %}
@ -421,7 +421,7 @@ static_resources:
{% for provider in arch_llm_providers %}
# if endpoint is set then use custom cluster for upstream llm
{% if provider.endpoint %}
{% set llm_cluster_name = provider.name %}
{% set llm_cluster_name = provider.cluster_name %}
{% else %}
{% set llm_cluster_name = provider.provider_interface %}
{% endif %}
@ -576,6 +576,82 @@ static_resources:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
# Upstream cluster for the xAI API: resolves api.x.ai via DNS (IPv4 only) and
# connects on port 443 over TLS.
- name: xai
connect_timeout: 0.5s
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
# load assignment carries the same name as the cluster itself
cluster_name: xai
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.x.ai
port_value: 443
hostname: "api.x.ai"
# TLS settings for the upstream connection
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
# SNI is set to the same host as the socket address above
sni: api.x.ai
common_tls_context:
tls_params:
# accept TLS 1.2 through TLS 1.3 only
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
# Upstream cluster for the Together AI API: resolves api.together.xyz via DNS
# (IPv4 only) and connects on port 443 over TLS.
- name: together_ai
connect_timeout: 0.5s
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
# must name this cluster (together_ai), not a different provider's cluster
cluster_name: together_ai
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.together.xyz
port_value: 443
hostname: "api.together.xyz"
# TLS settings for the upstream connection
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
# SNI is set to the same host as the socket address above
sni: api.together.xyz
common_tls_context:
tls_params:
# accept TLS 1.2 through TLS 1.3 only
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
# Upstream cluster for the Lambda AI API: resolves api.lambda.ai via DNS
# (IPv4 only) and connects on port 443 over TLS.
- name: lambda_ai
connect_timeout: 0.5s
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
# must name this cluster (lambda_ai), not a different provider's cluster
cluster_name: lambda_ai
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.lambda.ai
port_value: 443
hostname: "api.lambda.ai"
# TLS settings for the upstream connection
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
# SNI is set to the same host as the socket address above
sni: api.lambda.ai
common_tls_context:
tls_params:
# accept TLS 1.2 through TLS 1.3 only
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
- name: gemini
connect_timeout: 0.5s
type: LOGICAL_DNS
@ -742,13 +818,13 @@ static_resources:
{% endfor %}
{% for local_llm_provider in local_llms %}
- name: {{ local_llm_provider.name }}
- name: {{ local_llm_provider.cluster_name }}
connect_timeout: 0.5s
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: {{ local_llm_provider.name }}
cluster_name: {{ local_llm_provider.cluster_name }}
endpoints:
- lb_endpoints:
- endpoint:

View file

@ -14,6 +14,10 @@ SUPPORTED_PROVIDERS = [
"openai",
"gemini",
"anthropic",
"together_ai",
"lambda_ai",
"azure_openai",
"xai",
]
@ -92,15 +96,12 @@ def validate_and_render_schema():
arch_tracing = config_yaml.get("tracing", {})
llms_with_endpoint = []
updated_llm_providers = []
llm_provider_name_set = set()
llms_with_usage = []
model_name_keys = set()
model_usage_name_keys = set()
for llm_provider in config_yaml["llm_providers"]:
if llm_provider.get("usage", None):
llms_with_usage.append(llm_provider["name"])
if llm_provider.get("name") in llm_provider_name_set:
raise Exception(
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
@ -111,10 +112,13 @@ def validate_and_render_schema():
raise Exception(
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
)
model_name_keys.add(model_name)
if llm_provider.get("name") is None:
llm_provider["name"] = model_name
llm_provider_name_set.add(llm_provider.get("name"))
model_name_tokens = model_name.split("/")
if len(model_name_tokens) < 2:
raise Exception(
@ -151,16 +155,6 @@ def validate_and_render_schema():
llm_provider["model"] = model_id
llm_provider["provider_interface"] = provider
llm_provider_name_set.add(llm_provider.get("name"))
provider = None
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
raise Exception(
"Please provide either provider or provider_interface, not both"
)
if llm_provider.get("provider"):
provider = llm_provider["provider"]
llm_provider["provider_interface"] = provider
del llm_provider["provider"]
updated_llm_providers.append(llm_provider)
if llm_provider.get("base_url", None):
@ -189,6 +183,9 @@ def validate_and_render_schema():
llm_provider["endpoint"] = endpoint
llm_provider["port"] = port
llm_provider["protocol"] = protocol
llm_provider["cluster_name"] = (
provider + "_" + endpoint
) # make name unique by appending endpoint
llms_with_endpoint.append(llm_provider)
if len(model_usage_name_keys) > 0: