add support for model_providers

This commit is contained in:
Adil Hafeez 2025-09-30 12:18:29 -07:00
parent 2cebc0c85f
commit 92a8782332
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
10 changed files with 183 additions and 93 deletions

View file

@ -1,6 +1,6 @@
import json
import os
from cli.utils import convert_legacy_llm_providers
from cli.utils import convert_legacy_listeners
from jinja2 import Environment, FileSystemLoader
import yaml
from jsonschema import validate
@ -71,8 +71,17 @@ def validate_and_render_schema():
_ = yaml.safe_load(arch_config_schema)
inferred_clusters = {}
listeners, llm_gateway, prompt_gateway = convert_legacy_llm_providers(
config_yaml.get("listeners"), config_yaml.get("llm_providers")
# Convert legacy llm_providers to model_providers
if "llm_providers" in config_yaml:
if "model_providers" in config_yaml:
raise Exception(
"Please provide either llm_providers or model_providers, not both. llm_providers is deprecated, please use model_providers instead"
)
config_yaml["model_providers"] = config_yaml["llm_providers"]
del config_yaml["llm_providers"]
listeners, llm_gateway, prompt_gateway = convert_legacy_listeners(
config_yaml.get("listeners"), config_yaml.get("model_providers")
)
config_yaml["listeners"] = listeners
@ -130,36 +139,39 @@ def validate_and_render_schema():
arch_tracing = config_yaml.get("tracing", {})
llms_with_endpoint = []
updated_llm_providers = []
llm_provider_name_set = set()
updated_model_providers = []
model_provider_name_set = set()
llms_with_usage = []
model_name_keys = set()
model_usage_name_keys = set()
for listener in listeners:
if listener.get("llm_providers") is None or listener.get("llm_providers") == []:
if (
listener.get("model_providers") is None
or listener.get("model_providers") == []
):
continue
print("Processing listener with llm_providers: ", listener)
print("Processing listener with model_providers: ", listener)
name = listener.get("name", None)
for llm_provider in listener.get("llm_providers", []):
if llm_provider.get("usage", None):
llms_with_usage.append(llm_provider["name"])
if llm_provider.get("name") in llm_provider_name_set:
for model_provider in listener.get("model_providers", []):
if model_provider.get("usage", None):
llms_with_usage.append(model_provider["name"])
if model_provider.get("name") in model_provider_name_set:
raise Exception(
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
f"Duplicate model_provider name {model_provider.get('name')}, please provide unique name for each model_provider"
)
model_name = llm_provider.get("model")
model_name = model_provider.get("model")
if model_name in model_name_keys:
raise Exception(
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
f"Duplicate model name {model_name}, please provide unique model name for each model_provider"
)
model_name_keys.add(model_name)
if llm_provider.get("name") is None:
llm_provider["name"] = model_name
if model_provider.get("name") is None:
model_provider["name"] = model_name
llm_provider_name_set.add(llm_provider.get("name"))
model_provider_name_set.add(model_provider.get("name"))
model_name_tokens = model_name.split("/")
if len(model_name_tokens) < 2:
@ -171,7 +183,7 @@ def validate_and_render_schema():
# Validate azure_openai and ollama provider requires base_url
if (
provider == "azure_openai" or provider == "ollama"
) and llm_provider.get("base_url") is None:
) and model_provider.get("base_url") is None:
raise Exception(
f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
)
@ -179,46 +191,48 @@ def validate_and_render_schema():
model_id = "/".join(model_name_tokens[1:])
if provider not in SUPPORTED_PROVIDERS:
if (
llm_provider.get("base_url", None) is None
or llm_provider.get("provider_interface", None) is None
model_provider.get("base_url", None) is None
or model_provider.get("provider_interface", None) is None
):
raise Exception(
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
)
provider = llm_provider.get("provider_interface", None)
elif llm_provider.get("provider_interface", None) is not None:
provider = model_provider.get("provider_interface", None)
elif model_provider.get("provider_interface", None) is not None:
raise Exception(
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
)
if model_id in model_name_keys:
raise Exception(
f"Duplicate model_id {model_id}, please provide unique model_id for each llm_provider"
f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider"
)
model_name_keys.add(model_id)
for routing_preference in llm_provider.get("routing_preferences", []):
for routing_preference in model_provider.get("routing_preferences", []):
if routing_preference.get("name") in model_usage_name_keys:
raise Exception(
f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference"
)
model_usage_name_keys.add(routing_preference.get("name"))
llm_provider["model"] = model_id
llm_provider["provider_interface"] = provider
llm_provider_name_set.add(llm_provider.get("name"))
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
model_provider["model"] = model_id
model_provider["provider_interface"] = provider
model_provider_name_set.add(model_provider.get("name"))
if model_provider.get("provider") and model_provider.get(
"provider_interface"
):
raise Exception(
"Please provide either provider or provider_interface, not both"
)
if llm_provider.get("provider"):
provider = llm_provider["provider"]
llm_provider["provider_interface"] = provider
del llm_provider["provider"]
updated_llm_providers.append(llm_provider)
if model_provider.get("provider"):
provider = model_provider["provider"]
model_provider["provider_interface"] = provider
del model_provider["provider"]
updated_model_providers.append(model_provider)
if llm_provider.get("base_url", None):
base_url = llm_provider["base_url"]
if model_provider.get("base_url", None):
base_url = model_provider["base_url"]
urlparse_result = urlparse(base_url)
url_path = urlparse_result.path
if url_path and url_path != "/":
@ -240,22 +254,30 @@ def validate_and_render_schema():
else:
port = 443
endpoint = urlparse_result.hostname
llm_provider["endpoint"] = endpoint
llm_provider["port"] = port
llm_provider["protocol"] = protocol
llm_provider["cluster_name"] = (
model_provider["endpoint"] = endpoint
model_provider["port"] = port
model_provider["protocol"] = protocol
model_provider["cluster_name"] = (
provider + "_" + endpoint
) # make name unique by appending endpoint
llms_with_endpoint.append(llm_provider)
llms_with_endpoint.append(model_provider)
if len(model_usage_name_keys) > 0:
routing_llm_provider = config_yaml.get("routing", {}).get("llm_provider", None)
if routing_llm_provider and routing_llm_provider not in llm_provider_name_set:
routing_model_provider = config_yaml.get("routing", {}).get(
"model_provider", None
)
if (
routing_model_provider
and routing_model_provider not in model_provider_name_set
):
raise Exception(
f"Routing llm_provider {routing_llm_provider} is not defined in llm_providers"
f"Routing model_provider {routing_model_provider} is not defined in model_providers"
)
if routing_llm_provider is None and "arch-router" not in llm_provider_name_set:
updated_llm_providers.append(
if (
routing_model_provider is None
and "arch-router" not in model_provider_name_set
):
updated_model_providers.append(
{
"name": "arch-router",
"provider_interface": "arch",
@ -263,19 +285,19 @@ def validate_and_render_schema():
}
)
updated_llm_providers = []
updated_model_providers = []
for listener in listeners:
print("Processing listener: ", listener)
llm_providers = listener.get("llm_providers", None)
if llm_providers is not None and llm_providers != []:
model_providers = listener.get("model_providers", None)
if model_providers is not None and model_providers != []:
print("processing egress traffic listener")
print("updated_llm_providers: ", updated_llm_providers)
if updated_llm_providers is not None and updated_llm_providers != []:
print("updated_model_providers: ", updated_model_providers)
if updated_model_providers is not None and updated_model_providers != []:
raise Exception(
"Please provide llm_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with llm_providers"
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
)
updated_llm_providers = deepcopy(llm_providers)
config_yaml["llm_providers"] = updated_llm_providers
updated_model_providers = deepcopy(model_providers)
config_yaml["model_providers"] = updated_model_providers
# Validate model aliases if present
if "model_aliases" in config_yaml:
@ -317,7 +339,7 @@ def validate_and_render_schema():
"arch_config": arch_config_string,
"arch_llm_config": arch_llm_config_string,
"arch_clusters": inferred_clusters,
"arch_llm_providers": updated_llm_providers,
"arch_model_providers": updated_model_providers,
"arch_tracing": arch_tracing,
"local_llms": llms_with_endpoint,
"agent_orchestrator": agent_orchestrator,