draft commit to add support for xAI, TogetherAI, AzureOpenAI (#570)

* draft commit to add support for xAI, LambdaAI, TogetherAI, AzureOpenAI

* fixing failing tests and updating rendered config file

* Update arch_config_with_aliases.yaml

* adding the AZURE_API_KEY to the GH workflow for e2e

* fixing GH secrets

* adding validation for azure_openai

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-167.local>
This commit is contained in:
Salman Paracha 2025-09-18 18:36:30 -07:00 committed by GitHub
parent b56311f458
commit 8d0b468345
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 166 additions and 24 deletions

View file

@ -14,6 +14,9 @@ SUPPORTED_PROVIDERS = [
"openai",
"gemini",
"anthropic",
"together_ai",
"azure_openai",
"xai",
]
@ -92,15 +95,12 @@ def validate_and_render_schema():
arch_tracing = config_yaml.get("tracing", {})
llms_with_endpoint = []
updated_llm_providers = []
llm_provider_name_set = set()
llms_with_usage = []
model_name_keys = set()
model_usage_name_keys = set()
for llm_provider in config_yaml["llm_providers"]:
if llm_provider.get("usage", None):
llms_with_usage.append(llm_provider["name"])
if llm_provider.get("name") in llm_provider_name_set:
raise Exception(
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
@ -111,16 +111,25 @@ def validate_and_render_schema():
raise Exception(
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
)
model_name_keys.add(model_name)
if llm_provider.get("name") is None:
llm_provider["name"] = model_name
llm_provider_name_set.add(llm_provider.get("name"))
model_name_tokens = model_name.split("/")
if len(model_name_tokens) < 2:
raise Exception(
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
)
provider = model_name_tokens[0]
# Validate azure_openai provider requires base_url
if provider == "azure_openai" and llm_provider.get("base_url") is None:
raise Exception(
f"Provider 'azure_openai' requires 'base_url' to be set for model {model_name}"
)
model_id = "/".join(model_name_tokens[1:])
if provider not in SUPPORTED_PROVIDERS:
if (
@ -151,16 +160,6 @@ def validate_and_render_schema():
llm_provider["model"] = model_id
llm_provider["provider_interface"] = provider
llm_provider_name_set.add(llm_provider.get("name"))
provider = None
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
raise Exception(
"Please provide either provider or provider_interface, not both"
)
if llm_provider.get("provider"):
provider = llm_provider["provider"]
llm_provider["provider_interface"] = provider
del llm_provider["provider"]
updated_llm_providers.append(llm_provider)
if llm_provider.get("base_url", None):
@ -189,6 +188,9 @@ def validate_and_render_schema():
llm_provider["endpoint"] = endpoint
llm_provider["port"] = port
llm_provider["protocol"] = protocol
llm_provider["cluster_name"] = (
provider + "_" + endpoint
) # make name unique by appending endpoint
llms_with_endpoint.append(llm_provider)
if len(model_usage_name_keys) > 0: