add support for agents (#564)

This commit is contained in:
Adil Hafeez 2025-10-14 14:01:11 -07:00 committed by GitHub
parent f8991a3c4b
commit 96e0732089
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 3571 additions and 856 deletions

View file

@ -1,9 +1,11 @@
import json
import os
from cli.utils import convert_legacy_listeners
from jinja2 import Environment, FileSystemLoader
import yaml
from jsonschema import validate
from urllib.parse import urlparse
from copy import deepcopy
SUPPORTED_PROVIDERS = [
@ -72,17 +74,58 @@ def validate_and_render_schema():
_ = yaml.safe_load(arch_config_schema)
inferred_clusters = {}
# Convert legacy llm_providers to model_providers
if "llm_providers" in config_yaml:
if "model_providers" in config_yaml:
raise Exception(
"Please provide either llm_providers or model_providers, not both. llm_providers is deprecated, please use model_providers instead"
)
config_yaml["model_providers"] = config_yaml["llm_providers"]
del config_yaml["llm_providers"]
listeners, llm_gateway, prompt_gateway = convert_legacy_listeners(
config_yaml.get("listeners"), config_yaml.get("model_providers")
)
config_yaml["listeners"] = listeners
endpoints = config_yaml.get("endpoints", {})
# Process agents section and convert to endpoints
agents = config_yaml.get("agents", [])
for agent in agents:
agent_id = agent.get("id")
agent_endpoint = agent.get("url")
if agent_id and agent_endpoint:
urlparse_result = urlparse(agent_endpoint)
if urlparse_result.scheme and urlparse_result.hostname:
protocol = urlparse_result.scheme
port = urlparse_result.port
if port is None:
if protocol == "http":
port = 80
else:
port = 443
endpoints[agent_id] = {
"endpoint": urlparse_result.hostname,
"port": port,
"protocol": protocol,
}
# override the inferred clusters with the ones defined in the config
for name, endpoint_details in endpoints.items():
inferred_clusters[name] = endpoint_details
endpoint = inferred_clusters[name]["endpoint"]
protocol = inferred_clusters[name].get("protocol", "http")
(
inferred_clusters[name]["endpoint"],
inferred_clusters[name]["port"],
) = get_endpoint_and_port(endpoint, protocol)
# Only call get_endpoint_and_port for manually defined endpoints, not agent-derived ones
if "port" not in endpoint_details:
endpoint = inferred_clusters[name]["endpoint"]
protocol = inferred_clusters[name].get("protocol", "http")
(
inferred_clusters[name]["endpoint"],
inferred_clusters[name]["port"],
) = get_endpoint_and_port(endpoint, protocol)
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
@ -99,114 +142,148 @@ def validate_and_render_schema():
arch_tracing = config_yaml.get("tracing", {})
llms_with_endpoint = []
updated_llm_providers = []
llm_provider_name_set = set()
updated_model_providers = []
model_provider_name_set = set()
llms_with_usage = []
model_name_keys = set()
model_usage_name_keys = set()
for llm_provider in config_yaml["llm_providers"]:
if llm_provider.get("name") in llm_provider_name_set:
raise Exception(
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
)
model_name = llm_provider.get("model")
if model_name in model_name_keys:
raise Exception(
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
)
print("listeners: ", listeners)
model_name_keys.add(model_name)
if llm_provider.get("name") is None:
llm_provider["name"] = model_name
llm_provider_name_set.add(llm_provider.get("name"))
model_name_tokens = model_name.split("/")
if len(model_name_tokens) < 2:
raise Exception(
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
)
provider = model_name_tokens[0]
# Validate azure_openai and ollama provider requires base_url
for listener in listeners:
if (
provider == "azure_openai" or provider == "ollama" or provider == "qwen"
) and llm_provider.get("base_url") is None:
raise Exception(
f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
)
listener.get("model_providers") is None
or listener.get("model_providers") == []
):
continue
print("Processing listener with model_providers: ", listener)
name = listener.get("name", None)
model_id = "/".join(model_name_tokens[1:])
if provider not in SUPPORTED_PROVIDERS:
for model_provider in listener.get("model_providers", []):
if model_provider.get("usage", None):
llms_with_usage.append(model_provider["name"])
if model_provider.get("name") in model_provider_name_set:
raise Exception(
f"Duplicate model_provider name {model_provider.get('name')}, please provide unique name for each model_provider"
)
model_name = model_provider.get("model")
print("Processing model_provider: ", model_provider)
if model_name in model_name_keys:
raise Exception(
f"Duplicate model name {model_name}, please provide unique model name for each model_provider"
)
model_name_keys.add(model_name)
if model_provider.get("name") is None:
model_provider["name"] = model_name
model_provider_name_set.add(model_provider.get("name"))
model_name_tokens = model_name.split("/")
if len(model_name_tokens) < 2:
raise Exception(
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
)
provider = model_name_tokens[0]
# Validate azure_openai and ollama provider requires base_url
if (
llm_provider.get("base_url", None) is None
or llm_provider.get("provider_interface", None) is None
provider == "azure_openai" or provider == "ollama" or provider == "qwen"
) and model_provider.get("base_url") is None:
raise Exception(
f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
)
model_id = "/".join(model_name_tokens[1:])
if provider not in SUPPORTED_PROVIDERS:
if (
model_provider.get("base_url", None) is None
or model_provider.get("provider_interface", None) is None
):
raise Exception(
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
)
provider = model_provider.get("provider_interface", None)
elif model_provider.get("provider_interface", None) is not None:
raise Exception(
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
)
if model_id in model_name_keys:
raise Exception(
f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider"
)
model_name_keys.add(model_id)
for routing_preference in model_provider.get("routing_preferences", []):
if routing_preference.get("name") in model_usage_name_keys:
raise Exception(
f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference"
)
model_usage_name_keys.add(routing_preference.get("name"))
model_provider["model"] = model_id
model_provider["provider_interface"] = provider
model_provider_name_set.add(model_provider.get("name"))
if model_provider.get("provider") and model_provider.get(
"provider_interface"
):
raise Exception(
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
"Please provide either provider or provider_interface, not both"
)
provider = llm_provider.get("provider_interface", None)
elif llm_provider.get("provider_interface", None) is not None:
raise Exception(
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
)
if model_provider.get("provider"):
provider = model_provider["provider"]
model_provider["provider_interface"] = provider
del model_provider["provider"]
updated_model_providers.append(model_provider)
if model_id in model_name_keys:
raise Exception(
f"Duplicate model_id {model_id}, please provide unique model_id for each llm_provider"
)
model_name_keys.add(model_id)
for routing_preference in llm_provider.get("routing_preferences", []):
if routing_preference.get("name") in model_usage_name_keys:
raise Exception(
f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference"
)
model_usage_name_keys.add(routing_preference.get("name"))
llm_provider["model"] = model_id
llm_provider["provider_interface"] = provider
updated_llm_providers.append(llm_provider)
if llm_provider.get("base_url", None):
base_url = llm_provider["base_url"]
urlparse_result = urlparse(base_url)
url_path = urlparse_result.path
if url_path and url_path != "/":
raise Exception(
f"Please provide base_url without path, got {base_url}. Use base_url like 'http://example.com' instead of 'http://example.com/path'."
)
if urlparse_result.scheme == "" or urlparse_result.scheme not in [
"http",
"https",
]:
raise Exception(
"Please provide a valid URL with scheme (http/https) in base_url"
)
protocol = urlparse_result.scheme
port = urlparse_result.port
if port is None:
if protocol == "http":
port = 80
else:
port = 443
endpoint = urlparse_result.hostname
llm_provider["endpoint"] = endpoint
llm_provider["port"] = port
llm_provider["protocol"] = protocol
llm_provider["cluster_name"] = (
provider + "_" + endpoint
) # make name unique by appending endpoint
llms_with_endpoint.append(llm_provider)
if model_provider.get("base_url", None):
base_url = model_provider["base_url"]
urlparse_result = urlparse(base_url)
url_path = urlparse_result.path
if url_path and url_path != "/":
raise Exception(
f"Please provide base_url without path, got {base_url}. Use base_url like 'http://example.com' instead of 'http://example.com/path'."
)
if urlparse_result.scheme == "" or urlparse_result.scheme not in [
"http",
"https",
]:
raise Exception(
"Please provide a valid URL with scheme (http/https) in base_url"
)
protocol = urlparse_result.scheme
port = urlparse_result.port
if port is None:
if protocol == "http":
port = 80
else:
port = 443
endpoint = urlparse_result.hostname
model_provider["endpoint"] = endpoint
model_provider["port"] = port
model_provider["protocol"] = protocol
model_provider["cluster_name"] = (
provider + "_" + endpoint
) # make name unique by appending endpoint
llms_with_endpoint.append(model_provider)
if len(model_usage_name_keys) > 0:
routing_llm_provider = config_yaml.get("routing", {}).get("llm_provider", None)
if routing_llm_provider and routing_llm_provider not in llm_provider_name_set:
routing_model_provider = config_yaml.get("routing", {}).get(
"model_provider", None
)
if (
routing_model_provider
and routing_model_provider not in model_provider_name_set
):
raise Exception(
f"Routing llm_provider {routing_llm_provider} is not defined in llm_providers"
f"Routing model_provider {routing_model_provider} is not defined in model_providers"
)
if routing_llm_provider is None and "arch-router" not in llm_provider_name_set:
updated_llm_providers.append(
if (
routing_model_provider is None
and "arch-router" not in model_provider_name_set
):
updated_model_providers.append(
{
"name": "arch-router",
"provider_interface": "arch",
@ -214,7 +291,19 @@ def validate_and_render_schema():
}
)
config_yaml["llm_providers"] = updated_llm_providers
updated_model_providers = []
for listener in listeners:
print("Processing listener: ", listener)
model_providers = listener.get("model_providers", None)
if model_providers is not None and model_providers != []:
print("processing egress traffic listener")
print("updated_model_providers: ", updated_model_providers)
if updated_model_providers is not None and updated_model_providers != []:
raise Exception(
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
)
updated_model_providers = deepcopy(model_providers)
config_yaml["model_providers"] = updated_model_providers
# Validate model aliases if present
if "model_aliases" in config_yaml:
@ -223,30 +312,12 @@ def validate_and_render_schema():
target = alias_config.get("target")
if target not in model_name_keys:
raise Exception(
f"Model alias '{alias_name}' targets '{target}' which is not defined as a model. Available models: {', '.join(sorted(model_name_keys))}"
f"Model alias 2 - '{alias_name}' targets '{target}' which is not defined as a model. Available models: {', '.join(sorted(model_name_keys))}"
)
arch_config_string = yaml.dump(config_yaml)
arch_llm_config_string = yaml.dump(config_yaml)
prompt_gateway_listener = config_yaml.get("listeners", {}).get(
"ingress_traffic", {}
)
if prompt_gateway_listener.get("port") == None:
prompt_gateway_listener["port"] = 10000 # default port for prompt gateway
if prompt_gateway_listener.get("address") == None:
prompt_gateway_listener["address"] = "127.0.0.1"
if prompt_gateway_listener.get("timeout") == None:
prompt_gateway_listener["timeout"] = "10s"
llm_gateway_listener = config_yaml.get("listeners", {}).get("egress_traffic", {})
if llm_gateway_listener.get("port") == None:
llm_gateway_listener["port"] = 12000 # default port for llm gateway
if llm_gateway_listener.get("address") == None:
llm_gateway_listener["address"] = "127.0.0.1"
if llm_gateway_listener.get("timeout") == None:
llm_gateway_listener["timeout"] = "300s"
use_agent_orchestrator = config_yaml.get("overrides", {}).get(
"use_agent_orchestrator", False
)
@ -269,15 +340,16 @@ def validate_and_render_schema():
print("agent_orchestrator: ", agent_orchestrator)
data = {
"prompt_gateway_listener": prompt_gateway_listener,
"llm_gateway_listener": llm_gateway_listener,
"prompt_gateway_listener": prompt_gateway,
"llm_gateway_listener": llm_gateway,
"arch_config": arch_config_string,
"arch_llm_config": arch_llm_config_string,
"arch_clusters": inferred_clusters,
"arch_llm_providers": config_yaml["llm_providers"],
"arch_model_providers": updated_model_providers,
"arch_tracing": arch_tracing,
"local_llms": llms_with_endpoint,
"agent_orchestrator": agent_orchestrator,
"listeners": listeners,
}
rendered = template.render(data)