mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more changes
This commit is contained in:
parent
ac68e802d8
commit
4f31edfaf5
1 changed file with 192 additions and 110 deletions
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import os
|
||||
from cli.utils import convert_legacy_llm_providers
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
import yaml
|
||||
from jsonschema import validate
|
||||
from urllib.parse import urlparse
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
SUPPORTED_PROVIDERS = [
|
||||
|
|
@ -14,10 +16,6 @@ SUPPORTED_PROVIDERS = [
|
|||
"openai",
|
||||
"gemini",
|
||||
"anthropic",
|
||||
"together_ai",
|
||||
"azure_openai",
|
||||
"xai",
|
||||
"ollama",
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -69,17 +67,49 @@ def validate_and_render_schema():
|
|||
_ = yaml.safe_load(arch_config_schema)
|
||||
inferred_clusters = {}
|
||||
|
||||
listeners, llm_gateway, prompt_gateway = convert_legacy_llm_providers(
|
||||
config_yaml.get("listeners"), config_yaml.get("llm_providers")
|
||||
)
|
||||
|
||||
config_yaml["listeners"] = listeners
|
||||
|
||||
endpoints = config_yaml.get("endpoints", {})
|
||||
|
||||
# Process agents section and convert to endpoints
|
||||
agents = config_yaml.get("agents", [])
|
||||
for agent in agents:
|
||||
agent_name = agent.get("name")
|
||||
agent_endpoint = agent.get("endpoint")
|
||||
|
||||
if agent_name and agent_endpoint:
|
||||
urlparse_result = urlparse(agent_endpoint)
|
||||
if urlparse_result.scheme and urlparse_result.hostname:
|
||||
protocol = urlparse_result.scheme
|
||||
|
||||
port = urlparse_result.port
|
||||
if port is None:
|
||||
if protocol == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
|
||||
endpoints[agent_name] = {
|
||||
"endpoint": urlparse_result.hostname,
|
||||
"port": port,
|
||||
"protocol": protocol,
|
||||
}
|
||||
|
||||
# override the inferred clusters with the ones defined in the config
|
||||
for name, endpoint_details in endpoints.items():
|
||||
inferred_clusters[name] = endpoint_details
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
protocol = inferred_clusters[name].get("protocol", "http")
|
||||
(
|
||||
inferred_clusters[name]["endpoint"],
|
||||
inferred_clusters[name]["port"],
|
||||
) = get_endpoint_and_port(endpoint, protocol)
|
||||
# Only call get_endpoint_and_port for manually defined endpoints, not agent-derived ones
|
||||
if "port" not in endpoint_details:
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
protocol = inferred_clusters[name].get("protocol", "http")
|
||||
(
|
||||
inferred_clusters[name]["endpoint"],
|
||||
inferred_clusters[name]["port"],
|
||||
) = get_endpoint_and_port(endpoint, protocol)
|
||||
|
||||
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
|
||||
|
||||
|
|
@ -96,105 +126,162 @@ def validate_and_render_schema():
|
|||
arch_tracing = config_yaml.get("tracing", {})
|
||||
|
||||
llms_with_endpoint = []
|
||||
updated_llm_providers = []
|
||||
|
||||
updated_llm_providers = []
|
||||
llm_provider_name_set = set()
|
||||
llms_with_usage = []
|
||||
model_name_keys = set()
|
||||
model_usage_name_keys = set()
|
||||
for llm_provider in config_yaml["llm_providers"]:
|
||||
if llm_provider.get("name") in llm_provider_name_set:
|
||||
raise Exception(
|
||||
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
|
||||
)
|
||||
|
||||
model_name = llm_provider.get("model")
|
||||
if model_name in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
|
||||
)
|
||||
# # legacy listeners
|
||||
# # check if type is array or object
|
||||
# # if it's a dict, it's the legacy format, so convert it to an array
|
||||
# prompt_gateway_listener = {
|
||||
# "name": "ingress_traffic",
|
||||
# "port": 10000,
|
||||
# "address": "0.0.0.0",
|
||||
# "timeout": "30s",
|
||||
# "protocol": "openai",
|
||||
# }
|
||||
# llm_gateway_listener = {
|
||||
# "name": "egress_traffic",
|
||||
# "port": 12000,
|
||||
# "address": "0.0.0.0",
|
||||
# "timeout": "30s",
|
||||
# "llm_providers": [],
|
||||
# "protocol": "openai",
|
||||
# }
|
||||
# if isinstance(config_yaml["listeners"], dict):
|
||||
# ingress_traffic = config_yaml["listeners"].get("ingress_traffic", None)
|
||||
# egress_traffic = config_yaml["listeners"].get("egress_traffic", {})
|
||||
# config_yaml["listeners"] = []
|
||||
|
||||
model_name_keys.add(model_name)
|
||||
if llm_provider.get("name") is None:
|
||||
llm_provider["name"] = model_name
|
||||
# llm_providers = []
|
||||
# if config_yaml.get("llm_providers"):
|
||||
# llm_providers = config_yaml["llm_providers"]
|
||||
# del config_yaml["llm_providers"]
|
||||
# llm_gateway_listener["port"] = egress_traffic.get(
|
||||
# "port", llm_gateway_listener["port"]
|
||||
# )
|
||||
# llm_gateway_listener["address"] = egress_traffic.get(
|
||||
# "address", llm_gateway_listener["address"]
|
||||
# )
|
||||
# llm_gateway_listener["timeout"] = egress_traffic.get(
|
||||
# "timeout", llm_gateway_listener["timeout"]
|
||||
# )
|
||||
# llm_gateway_listener["llm_providers"] = llm_providers
|
||||
# config_yaml["listeners"].append(llm_gateway_listener)
|
||||
|
||||
llm_provider_name_set.add(llm_provider.get("name"))
|
||||
# if ingress_traffic:
|
||||
# prompt_gateway_listener["port"] = ingress_traffic.get(
|
||||
# "port", prompt_gateway_listener["port"]
|
||||
# )
|
||||
# prompt_gateway_listener["address"] = ingress_traffic.get(
|
||||
# "address", prompt_gateway_listener["address"]
|
||||
# )
|
||||
# prompt_gateway_listener["timeout"] = ingress_traffic.get(
|
||||
# "timeout", prompt_gateway_listener["timeout"]
|
||||
# )
|
||||
# config_yaml["listeners"].append(prompt_gateway_listener)
|
||||
|
||||
model_name_tokens = model_name.split("/")
|
||||
if len(model_name_tokens) < 2:
|
||||
raise Exception(
|
||||
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
|
||||
)
|
||||
provider = model_name_tokens[0]
|
||||
# Validate that the azure_openai and ollama providers have base_url set
|
||||
if (provider == "azure_openai" or provider == "ollama") and llm_provider.get(
|
||||
"base_url"
|
||||
) is None:
|
||||
raise Exception(
|
||||
f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
|
||||
)
|
||||
for listener in listeners:
|
||||
if listener.get("llm_providers") is None or listener.get("llm_providers") == []:
|
||||
continue
|
||||
print("Processing listener with llm_providers: ", listener)
|
||||
name = listener.get("name", None)
|
||||
|
||||
model_id = "/".join(model_name_tokens[1:])
|
||||
if provider not in SUPPORTED_PROVIDERS:
|
||||
if (
|
||||
llm_provider.get("base_url", None) is None
|
||||
or llm_provider.get("provider_interface", None) is None
|
||||
):
|
||||
for llm_provider in listener.get("llm_providers", []):
|
||||
if llm_provider.get("usage", None):
|
||||
llms_with_usage.append(llm_provider["name"])
|
||||
if llm_provider.get("name") in llm_provider_name_set:
|
||||
raise Exception(
|
||||
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
|
||||
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
|
||||
)
|
||||
provider = llm_provider.get("provider_interface", None)
|
||||
elif llm_provider.get("provider_interface", None) is not None:
|
||||
raise Exception(
|
||||
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
|
||||
)
|
||||
|
||||
if model_id in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model_id {model_id}, please provide unique model_id for each llm_provider"
|
||||
)
|
||||
model_name_keys.add(model_id)
|
||||
|
||||
for routing_preference in llm_provider.get("routing_preferences", []):
|
||||
if routing_preference.get("name") in model_usage_name_keys:
|
||||
model_name = llm_provider.get("model")
|
||||
if model_name in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference"
|
||||
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
|
||||
)
|
||||
model_usage_name_keys.add(routing_preference.get("name"))
|
||||
model_name_keys.add(model_name)
|
||||
if llm_provider.get("name") is None:
|
||||
llm_provider["name"] = model_name
|
||||
|
||||
llm_provider["model"] = model_id
|
||||
llm_provider["provider_interface"] = provider
|
||||
updated_llm_providers.append(llm_provider)
|
||||
model_name_tokens = model_name.split("/")
|
||||
if len(model_name_tokens) < 2:
|
||||
raise Exception(
|
||||
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
|
||||
)
|
||||
provider = model_name_tokens[0]
|
||||
model_id = "/".join(model_name_tokens[1:])
|
||||
if provider not in SUPPORTED_PROVIDERS:
|
||||
if (
|
||||
llm_provider.get("base_url", None) is None
|
||||
or llm_provider.get("provider_interface", None) is None
|
||||
):
|
||||
raise Exception(
|
||||
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
|
||||
)
|
||||
provider = llm_provider.get("provider_interface", None)
|
||||
elif llm_provider.get("provider_interface", None) is not None:
|
||||
raise Exception(
|
||||
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
|
||||
)
|
||||
|
||||
if llm_provider.get("base_url", None):
|
||||
base_url = llm_provider["base_url"]
|
||||
urlparse_result = urlparse(base_url)
|
||||
url_path = urlparse_result.path
|
||||
if url_path and url_path != "/":
|
||||
if model_id in model_name_keys:
|
||||
raise Exception(
|
||||
f"Please provide base_url without path, got {base_url}. Use base_url like 'http://example.com' instead of 'http://example.com/path'."
|
||||
f"Duplicate model_id {model_id}, please provide unique model_id for each llm_provider"
|
||||
)
|
||||
if urlparse_result.scheme == "" or urlparse_result.scheme not in [
|
||||
"http",
|
||||
"https",
|
||||
]:
|
||||
model_name_keys.add(model_id)
|
||||
|
||||
for routing_preference in llm_provider.get("routing_preferences", []):
|
||||
if routing_preference.get("name") in model_usage_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference"
|
||||
)
|
||||
model_usage_name_keys.add(routing_preference.get("name"))
|
||||
|
||||
llm_provider["model"] = model_id
|
||||
llm_provider["provider_interface"] = provider
|
||||
llm_provider_name_set.add(llm_provider.get("name"))
|
||||
provider = None
|
||||
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
|
||||
raise Exception(
|
||||
"Please provide a valid URL with scheme (http/https) in base_url"
|
||||
"Please provide either provider or provider_interface, not both"
|
||||
)
|
||||
protocol = urlparse_result.scheme
|
||||
port = urlparse_result.port
|
||||
if port is None:
|
||||
if protocol == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
endpoint = urlparse_result.hostname
|
||||
llm_provider["endpoint"] = endpoint
|
||||
llm_provider["port"] = port
|
||||
llm_provider["protocol"] = protocol
|
||||
llm_provider["cluster_name"] = (
|
||||
provider + "_" + endpoint
|
||||
) # make name unique by appending endpoint
|
||||
llms_with_endpoint.append(llm_provider)
|
||||
if llm_provider.get("provider"):
|
||||
provider = llm_provider["provider"]
|
||||
llm_provider["provider_interface"] = provider
|
||||
del llm_provider["provider"]
|
||||
updated_llm_providers.append(llm_provider)
|
||||
|
||||
if llm_provider.get("base_url", None):
|
||||
base_url = llm_provider["base_url"]
|
||||
urlparse_result = urlparse(base_url)
|
||||
url_path = urlparse_result.path
|
||||
if url_path and url_path != "/":
|
||||
raise Exception(
|
||||
f"Please provide base_url without path, got {base_url}. Use base_url like 'http://example.com' instead of 'http://example.com/path'."
|
||||
)
|
||||
if urlparse_result.scheme == "" or urlparse_result.scheme not in [
|
||||
"http",
|
||||
"https",
|
||||
]:
|
||||
raise Exception(
|
||||
"Please provide a valid URL with scheme (http/https) in base_url"
|
||||
)
|
||||
protocol = urlparse_result.scheme
|
||||
port = urlparse_result.port
|
||||
if port is None:
|
||||
if protocol == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
endpoint = urlparse_result.hostname
|
||||
llm_provider["endpoint"] = endpoint
|
||||
llm_provider["port"] = port
|
||||
llm_provider["protocol"] = protocol
|
||||
llms_with_endpoint.append(llm_provider)
|
||||
|
||||
if len(model_usage_name_keys) > 0:
|
||||
routing_llm_provider = config_yaml.get("routing", {}).get("llm_provider", None)
|
||||
|
|
@ -211,6 +298,18 @@ def validate_and_render_schema():
|
|||
}
|
||||
)
|
||||
|
||||
updated_llm_providers = []
|
||||
for listener in listeners:
|
||||
print("Processing listener: ", listener)
|
||||
llm_providers = listener.get("llm_providers", None)
|
||||
if llm_providers is not None and llm_providers != []:
|
||||
print("processing egress traffic listener")
|
||||
print("updated_llm_providers: ", updated_llm_providers)
|
||||
if updated_llm_providers is not None and updated_llm_providers != []:
|
||||
raise Exception(
|
||||
"Please provide llm_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with llm_providers"
|
||||
)
|
||||
updated_llm_providers = deepcopy(llm_providers)
|
||||
config_yaml["llm_providers"] = updated_llm_providers
|
||||
|
||||
# Validate model aliases if present
|
||||
|
|
@ -226,24 +325,6 @@ def validate_and_render_schema():
|
|||
arch_config_string = yaml.dump(config_yaml)
|
||||
arch_llm_config_string = yaml.dump(config_yaml)
|
||||
|
||||
prompt_gateway_listener = config_yaml.get("listeners", {}).get(
|
||||
"ingress_traffic", {}
|
||||
)
|
||||
if prompt_gateway_listener.get("port") == None:
|
||||
prompt_gateway_listener["port"] = 10000 # default port for prompt gateway
|
||||
if prompt_gateway_listener.get("address") == None:
|
||||
prompt_gateway_listener["address"] = "127.0.0.1"
|
||||
if prompt_gateway_listener.get("timeout") == None:
|
||||
prompt_gateway_listener["timeout"] = "10s"
|
||||
|
||||
llm_gateway_listener = config_yaml.get("listeners", {}).get("egress_traffic", {})
|
||||
if llm_gateway_listener.get("port") == None:
|
||||
llm_gateway_listener["port"] = 12000 # default port for llm gateway
|
||||
if llm_gateway_listener.get("address") == None:
|
||||
llm_gateway_listener["address"] = "127.0.0.1"
|
||||
if llm_gateway_listener.get("timeout") == None:
|
||||
llm_gateway_listener["timeout"] = "10s"
|
||||
|
||||
use_agent_orchestrator = config_yaml.get("overrides", {}).get(
|
||||
"use_agent_orchestrator", False
|
||||
)
|
||||
|
|
@ -266,15 +347,16 @@ def validate_and_render_schema():
|
|||
print("agent_orchestrator: ", agent_orchestrator)
|
||||
|
||||
data = {
|
||||
"prompt_gateway_listener": prompt_gateway_listener,
|
||||
"llm_gateway_listener": llm_gateway_listener,
|
||||
"prompt_gateway_listener": prompt_gateway,
|
||||
"llm_gateway_listener": llm_gateway,
|
||||
"arch_config": arch_config_string,
|
||||
"arch_llm_config": arch_llm_config_string,
|
||||
"arch_clusters": inferred_clusters,
|
||||
"arch_llm_providers": config_yaml["llm_providers"],
|
||||
"arch_llm_providers": updated_llm_providers,
|
||||
"arch_tracing": arch_tracing,
|
||||
"local_llms": llms_with_endpoint,
|
||||
"agent_orchestrator": agent_orchestrator,
|
||||
"listeners": listeners,
|
||||
}
|
||||
|
||||
rendered = template.render(data)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue