draft commit to add support for xAI, LambdaAI, TogehterAI, AzureOpenAI

This commit is contained in:
Salman Paracha 2025-09-17 22:47:33 -07:00
parent b56311f458
commit 79ff4bb164
7 changed files with 170 additions and 24 deletions

View file

@ -14,6 +14,10 @@ SUPPORTED_PROVIDERS = [
"openai",
"gemini",
"anthropic",
"together_ai",
"lambda_ai",
"azure_openai",
"xai",
]
@ -92,15 +96,12 @@ def validate_and_render_schema():
arch_tracing = config_yaml.get("tracing", {})
llms_with_endpoint = []
updated_llm_providers = []
llm_provider_name_set = set()
llms_with_usage = []
model_name_keys = set()
model_usage_name_keys = set()
for llm_provider in config_yaml["llm_providers"]:
if llm_provider.get("usage", None):
llms_with_usage.append(llm_provider["name"])
if llm_provider.get("name") in llm_provider_name_set:
raise Exception(
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
@ -111,10 +112,13 @@ def validate_and_render_schema():
raise Exception(
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
)
model_name_keys.add(model_name)
if llm_provider.get("name") is None:
llm_provider["name"] = model_name
llm_provider_name_set.add(llm_provider.get("name"))
model_name_tokens = model_name.split("/")
if len(model_name_tokens) < 2:
raise Exception(
@ -151,16 +155,6 @@ def validate_and_render_schema():
llm_provider["model"] = model_id
llm_provider["provider_interface"] = provider
llm_provider_name_set.add(llm_provider.get("name"))
provider = None
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
raise Exception(
"Please provide either provider or provider_interface, not both"
)
if llm_provider.get("provider"):
provider = llm_provider["provider"]
llm_provider["provider_interface"] = provider
del llm_provider["provider"]
updated_llm_providers.append(llm_provider)
if llm_provider.get("base_url", None):
@ -189,6 +183,9 @@ def validate_and_render_schema():
llm_provider["endpoint"] = endpoint
llm_provider["port"] = port
llm_provider["protocol"] = protocol
llm_provider["cluster_name"] = (
provider + "_" + endpoint
) # make name unique by appending endpoint
llms_with_endpoint.append(llm_provider)
if len(model_usage_name_keys) > 0: