mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Refactor config_generator into modular, testable components
Break the 507-line monolithic validate_and_render_schema() into focused modules with pure validation functions, proper error handling, and clean I/O separation: - config_providers.py: provider constants, ConfigValidationError, unified URL parsing (replaces 3 different inline implementations) - config_validator.py: 11 pure validation functions (no I/O, no print/exit) - config_generator.py: thin 146-line I/O orchestrator, reads files once (was twice), uses logging instead of print() Also cleans up module responsibilities: - Move stream_access_logs from utils.py to docker_cli.py (Docker operation) - Deduplicate llm_providers->model_providers migration - Fix "Model alias 2 -" debug artifact in error message - Update docker-compose.dev.yaml volume mounts for new files - Rewrite tests: 53 tests calling pure functions directly (no mock_open chains), up from 10 brittle mock-dependent tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3c8e899de3
commit
5774452195
8 changed files with 1304 additions and 876 deletions
|
|
@ -1,506 +1,145 @@
|
|||
import json
|
||||
"""Config generator: loads config files, validates, and renders Envoy template.
|
||||
|
||||
This module is the I/O boundary. It reads files, calls pure validation
|
||||
functions from config_validator and config_providers, then writes output.
|
||||
|
||||
Entry point: ``python -m planoai.config_generator`` (called by supervisord).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from planoai.utils import convert_legacy_listeners
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
import yaml
|
||||
from jsonschema import validate
|
||||
from urllib.parse import urlparse
|
||||
from copy import deepcopy
|
||||
from planoai.consts import DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
|
||||
SUPPORTED_PROVIDERS_WITH_BASE_URL = [
|
||||
"azure_openai",
|
||||
"ollama",
|
||||
"qwen",
|
||||
"amazon_bedrock",
|
||||
"arch",
|
||||
]
|
||||
|
||||
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
|
||||
"deepseek",
|
||||
"groq",
|
||||
"mistral",
|
||||
"openai",
|
||||
"gemini",
|
||||
"anthropic",
|
||||
"together_ai",
|
||||
"xai",
|
||||
"moonshotai",
|
||||
"zhipu",
|
||||
]
|
||||
|
||||
SUPPORTED_PROVIDERS = (
|
||||
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL + SUPPORTED_PROVIDERS_WITH_BASE_URL
|
||||
from planoai.config_providers import (
|
||||
ConfigValidationError,
|
||||
# Re-export for backward compatibility
|
||||
SUPPORTED_PROVIDERS,
|
||||
SUPPORTED_PROVIDERS_WITH_BASE_URL,
|
||||
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL,
|
||||
)
|
||||
from planoai.config_validator import (
|
||||
build_clusters,
|
||||
build_template_data,
|
||||
migrate_legacy_providers,
|
||||
process_model_providers,
|
||||
resolve_agent_orchestrator,
|
||||
validate_agents,
|
||||
validate_listeners,
|
||||
validate_model_aliases,
|
||||
validate_prompt_targets,
|
||||
validate_schema,
|
||||
validate_tracing,
|
||||
)
|
||||
from planoai.consts import DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
|
||||
from planoai.utils import convert_legacy_listeners
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_endpoint_and_port(endpoint, protocol):
    """Split a ``host[:port]`` string into a ``(host, port)`` tuple.

    If ``endpoint`` contains a ``:``, the text before the first colon is the
    host and the text between the first and second colon is converted with
    ``int()`` to give the port. Otherwise the port defaults to 80 when
    ``protocol`` is ``"http"`` and to 443 for any other protocol.
    """
    host, *remainder = endpoint.split(":")
    if remainder:
        return host, int(remainder[0])
    return endpoint, (80 if protocol == "http" else 443)
|
||||
def load_yaml_file(path):
    """Read the file at ``path`` and return the result of ``yaml.safe_load``.

    The file is opened in text mode, read fully into memory, and closed
    before the contents are parsed.
    """
    with open(path, "r") as handle:
        contents = handle.read()
    parsed = yaml.safe_load(contents)
    return parsed
|
||||
|
||||
|
||||
def validate_and_render_schema():
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
|
||||
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
|
||||
)
|
||||
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml")
|
||||
ARCH_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"""Main orchestrator: load -> validate -> process -> render -> write.
|
||||
|
||||
Reads env vars for file paths (Docker integration).
|
||||
Raises ConfigValidationError on validation failure.
|
||||
"""
|
||||
# --- Read environment config ---
|
||||
template_file = os.getenv("ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml")
|
||||
config_path = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml")
|
||||
rendered_config_path = os.getenv(
|
||||
"ARCH_CONFIG_FILE_RENDERED", "/app/arch_config_rendered.yaml"
|
||||
)
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
|
||||
envoy_rendered_path = os.getenv(
|
||||
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
|
||||
)
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
|
||||
"ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
|
||||
)
|
||||
schema_path = os.getenv("ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml")
|
||||
template_root = os.getenv("TEMPLATE_ROOT", "./")
|
||||
|
||||
env = Environment(loader=FileSystemLoader(os.getenv("TEMPLATE_ROOT", "./")))
|
||||
template = env.get_template(ENVOY_CONFIG_TEMPLATE_FILE)
|
||||
# --- Load files (each read exactly once) ---
|
||||
config = load_yaml_file(config_path)
|
||||
schema = load_yaml_file(schema_path)
|
||||
|
||||
try:
|
||||
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
exit(1) # validate_prompt_config failed. Exit
|
||||
env = Environment(loader=FileSystemLoader(template_root))
|
||||
template = env.get_template(template_file)
|
||||
|
||||
with open(ARCH_CONFIG_FILE, "r") as file:
|
||||
arch_config = file.read()
|
||||
|
||||
with open(ARCH_CONFIG_SCHEMA_FILE, "r") as file:
|
||||
arch_config_schema = file.read()
|
||||
|
||||
config_yaml = yaml.safe_load(arch_config)
|
||||
_ = yaml.safe_load(arch_config_schema)
|
||||
inferred_clusters = {}
|
||||
|
||||
# Convert legacy llm_providers to model_providers
|
||||
if "llm_providers" in config_yaml:
|
||||
if "model_providers" in config_yaml:
|
||||
raise Exception(
|
||||
"Please provide either llm_providers or model_providers, not both. llm_providers is deprecated, please use model_providers instead"
|
||||
)
|
||||
config_yaml["model_providers"] = config_yaml["llm_providers"]
|
||||
del config_yaml["llm_providers"]
|
||||
# --- Validate and process ---
|
||||
validate_schema(config, schema)
|
||||
config = migrate_legacy_providers(config)
|
||||
|
||||
listeners, llm_gateway, prompt_gateway = convert_legacy_listeners(
|
||||
config_yaml.get("listeners"), config_yaml.get("model_providers")
|
||||
config.get("listeners"), config.get("model_providers")
|
||||
)
|
||||
config["listeners"] = listeners
|
||||
|
||||
agent_endpoints = validate_agents(
|
||||
config.get("agents", []), config.get("filters", [])
|
||||
)
|
||||
clusters = build_clusters(config.get("endpoints", {}), agent_endpoints)
|
||||
log.info("Defined clusters: %s", clusters)
|
||||
|
||||
validate_prompt_targets(config, clusters)
|
||||
|
||||
tracing = validate_tracing(
|
||||
config.get("tracing", {}), DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
|
||||
)
|
||||
|
||||
config_yaml["listeners"] = listeners
|
||||
|
||||
endpoints = config_yaml.get("endpoints", {})
|
||||
|
||||
# Process agents section and convert to endpoints
|
||||
agents = config_yaml.get("agents", [])
|
||||
filters = config_yaml.get("filters", [])
|
||||
agents_combined = agents + filters
|
||||
agent_id_keys = set()
|
||||
|
||||
for agent in agents_combined:
|
||||
agent_id = agent.get("id")
|
||||
if agent_id in agent_id_keys:
|
||||
raise Exception(
|
||||
f"Duplicate agent id {agent_id}, please provide unique id for each agent"
|
||||
)
|
||||
agent_id_keys.add(agent_id)
|
||||
agent_endpoint = agent.get("url")
|
||||
|
||||
if agent_id and agent_endpoint:
|
||||
urlparse_result = urlparse(agent_endpoint)
|
||||
if urlparse_result.scheme and urlparse_result.hostname:
|
||||
protocol = urlparse_result.scheme
|
||||
|
||||
port = urlparse_result.port
|
||||
if port is None:
|
||||
if protocol == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
|
||||
endpoints[agent_id] = {
|
||||
"endpoint": urlparse_result.hostname,
|
||||
"port": port,
|
||||
"protocol": protocol,
|
||||
}
|
||||
|
||||
# override the inferred clusters with the ones defined in the config
|
||||
for name, endpoint_details in endpoints.items():
|
||||
inferred_clusters[name] = endpoint_details
|
||||
# Only call get_endpoint_and_port for manually defined endpoints, not agent-derived ones
|
||||
if "port" not in endpoint_details:
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
protocol = inferred_clusters[name].get("protocol", "http")
|
||||
(
|
||||
inferred_clusters[name]["endpoint"],
|
||||
inferred_clusters[name]["port"],
|
||||
) = get_endpoint_and_port(endpoint, protocol)
|
||||
|
||||
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
|
||||
|
||||
if "prompt_targets" in config_yaml:
|
||||
for prompt_target in config_yaml["prompt_targets"]:
|
||||
name = prompt_target.get("endpoint", {}).get("name", None)
|
||||
if not name:
|
||||
continue
|
||||
if name not in inferred_clusters:
|
||||
raise Exception(
|
||||
f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file"
|
||||
)
|
||||
|
||||
arch_tracing = config_yaml.get("tracing", {})
|
||||
|
||||
# Resolution order: config yaml > OTEL_TRACING_GRPC_ENDPOINT env var > hardcoded default
|
||||
opentracing_grpc_endpoint = arch_tracing.get(
|
||||
"opentracing_grpc_endpoint",
|
||||
os.environ.get(
|
||||
"OTEL_TRACING_GRPC_ENDPOINT", DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
|
||||
),
|
||||
updated_providers, llms_with_endpoint, model_name_keys = process_model_providers(
|
||||
listeners, config.get("routing", {})
|
||||
)
|
||||
# resolve env vars in opentracing_grpc_endpoint if present
|
||||
if opentracing_grpc_endpoint and "$" in opentracing_grpc_endpoint:
|
||||
opentracing_grpc_endpoint = os.path.expandvars(opentracing_grpc_endpoint)
|
||||
print(
|
||||
f"Resolved opentracing_grpc_endpoint to {opentracing_grpc_endpoint} after expanding environment variables"
|
||||
)
|
||||
arch_tracing["opentracing_grpc_endpoint"] = opentracing_grpc_endpoint
|
||||
# ensure that opentracing_grpc_endpoint is a valid URL if present and start with http and must not have any path
|
||||
if opentracing_grpc_endpoint:
|
||||
urlparse_result = urlparse(opentracing_grpc_endpoint)
|
||||
if urlparse_result.scheme != "http":
|
||||
raise Exception(
|
||||
f"Invalid opentracing_grpc_endpoint {opentracing_grpc_endpoint}, scheme must be http"
|
||||
)
|
||||
if urlparse_result.path and urlparse_result.path != "/":
|
||||
raise Exception(
|
||||
f"Invalid opentracing_grpc_endpoint {opentracing_grpc_endpoint}, path must be empty"
|
||||
)
|
||||
config["model_providers"] = deepcopy(updated_providers)
|
||||
|
||||
llms_with_endpoint = []
|
||||
llms_with_endpoint_cluster_names = set()
|
||||
updated_model_providers = []
|
||||
model_provider_name_set = set()
|
||||
llms_with_usage = []
|
||||
model_name_keys = set()
|
||||
model_usage_name_keys = set()
|
||||
validate_listeners(listeners)
|
||||
|
||||
print("listeners: ", listeners)
|
||||
if "model_aliases" in config:
|
||||
validate_model_aliases(config["model_aliases"], model_name_keys)
|
||||
|
||||
for listener in listeners:
|
||||
if (
|
||||
listener.get("model_providers") is None
|
||||
or listener.get("model_providers") == []
|
||||
):
|
||||
continue
|
||||
print("Processing listener with model_providers: ", listener)
|
||||
name = listener.get("name", None)
|
||||
|
||||
for model_provider in listener.get("model_providers", []):
|
||||
if model_provider.get("usage", None):
|
||||
llms_with_usage.append(model_provider["name"])
|
||||
if model_provider.get("name") in model_provider_name_set:
|
||||
raise Exception(
|
||||
f"Duplicate model_provider name {model_provider.get('name')}, please provide unique name for each model_provider"
|
||||
)
|
||||
|
||||
model_name = model_provider.get("model")
|
||||
print("Processing model_provider: ", model_provider)
|
||||
|
||||
# Check if this is a wildcard model (provider/*)
|
||||
is_wildcard = False
|
||||
if "/" in model_name:
|
||||
model_name_tokens = model_name.split("/")
|
||||
if len(model_name_tokens) >= 2 and model_name_tokens[-1] == "*":
|
||||
is_wildcard = True
|
||||
|
||||
if model_name in model_name_keys and not is_wildcard:
|
||||
raise Exception(
|
||||
f"Duplicate model name {model_name}, please provide unique model name for each model_provider"
|
||||
)
|
||||
|
||||
if not is_wildcard:
|
||||
model_name_keys.add(model_name)
|
||||
if model_provider.get("name") is None:
|
||||
model_provider["name"] = model_name
|
||||
|
||||
model_provider_name_set.add(model_provider.get("name"))
|
||||
|
||||
model_name_tokens = model_name.split("/")
|
||||
if len(model_name_tokens) < 2:
|
||||
raise Exception(
|
||||
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id> or <provider>/* for wildcards."
|
||||
)
|
||||
provider = model_name_tokens[0].strip()
|
||||
|
||||
# Check if this is a wildcard (provider/*)
|
||||
is_wildcard = model_name_tokens[-1].strip() == "*"
|
||||
|
||||
# Validate wildcard constraints
|
||||
if is_wildcard:
|
||||
if model_provider.get("default", False):
|
||||
raise Exception(
|
||||
f"Model {model_name} is configured as default but uses wildcard (*). Default models cannot be wildcards."
|
||||
)
|
||||
if model_provider.get("routing_preferences"):
|
||||
raise Exception(
|
||||
f"Model {model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards."
|
||||
)
|
||||
|
||||
# Validate azure_openai and ollama provider requires base_url
|
||||
if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get(
|
||||
"base_url"
|
||||
) is None:
|
||||
raise Exception(
|
||||
f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
|
||||
)
|
||||
|
||||
model_id = "/".join(model_name_tokens[1:])
|
||||
|
||||
# For wildcard providers, allow any provider name
|
||||
if not is_wildcard and provider not in SUPPORTED_PROVIDERS:
|
||||
if (
|
||||
model_provider.get("base_url", None) is None
|
||||
or model_provider.get("provider_interface", None) is None
|
||||
):
|
||||
raise Exception(
|
||||
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
|
||||
)
|
||||
provider = model_provider.get("provider_interface", None)
|
||||
elif is_wildcard and provider not in SUPPORTED_PROVIDERS:
|
||||
# Wildcard models with unsupported providers require base_url and provider_interface
|
||||
if (
|
||||
model_provider.get("base_url", None) is None
|
||||
or model_provider.get("provider_interface", None) is None
|
||||
):
|
||||
raise Exception(
|
||||
f"Must provide base_url and provider_interface for unsupported provider {provider} for wildcard model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
|
||||
)
|
||||
provider = model_provider.get("provider_interface", None)
|
||||
elif (
|
||||
provider in SUPPORTED_PROVIDERS
|
||||
and model_provider.get("provider_interface", None) is not None
|
||||
):
|
||||
# For supported providers, provider_interface should not be manually set
|
||||
raise Exception(
|
||||
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
|
||||
)
|
||||
|
||||
# For wildcard models, don't add model_id to the keys since it's "*"
|
||||
if not is_wildcard:
|
||||
if model_id in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider"
|
||||
)
|
||||
model_name_keys.add(model_id)
|
||||
|
||||
for routing_preference in model_provider.get("routing_preferences", []):
|
||||
if routing_preference.get("name") in model_usage_name_keys:
|
||||
raise Exception(
|
||||
f'Duplicate routing preference name "{routing_preference.get("name")}", please provide unique name for each routing preference'
|
||||
)
|
||||
model_usage_name_keys.add(routing_preference.get("name"))
|
||||
|
||||
# Warn if both passthrough_auth and access_key are configured
|
||||
if model_provider.get("passthrough_auth") and model_provider.get(
|
||||
"access_key"
|
||||
):
|
||||
print(
|
||||
f"WARNING: Model provider '{model_provider.get('name')}' has both 'passthrough_auth: true' and 'access_key' configured. "
|
||||
f"The access_key will be ignored and the client's Authorization header will be forwarded instead."
|
||||
)
|
||||
|
||||
model_provider["model"] = model_id
|
||||
model_provider["provider_interface"] = provider
|
||||
model_provider_name_set.add(model_provider.get("name"))
|
||||
if model_provider.get("provider") and model_provider.get(
|
||||
"provider_interface"
|
||||
):
|
||||
raise Exception(
|
||||
"Please provide either provider or provider_interface, not both"
|
||||
)
|
||||
if model_provider.get("provider"):
|
||||
provider = model_provider["provider"]
|
||||
model_provider["provider_interface"] = provider
|
||||
del model_provider["provider"]
|
||||
updated_model_providers.append(model_provider)
|
||||
|
||||
if model_provider.get("base_url", None):
|
||||
base_url = model_provider["base_url"]
|
||||
urlparse_result = urlparse(base_url)
|
||||
base_url_path_prefix = urlparse_result.path
|
||||
if base_url_path_prefix and base_url_path_prefix != "/":
|
||||
# we will now support base_url_path_prefix. This means that the user can provide base_url like http://example.com/path and we will extract /path as base_url_path_prefix
|
||||
model_provider["base_url_path_prefix"] = base_url_path_prefix
|
||||
|
||||
if urlparse_result.scheme == "" or urlparse_result.scheme not in [
|
||||
"http",
|
||||
"https",
|
||||
]:
|
||||
raise Exception(
|
||||
"Please provide a valid URL with scheme (http/https) in base_url"
|
||||
)
|
||||
protocol = urlparse_result.scheme
|
||||
port = urlparse_result.port
|
||||
if port is None:
|
||||
if protocol == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
endpoint = urlparse_result.hostname
|
||||
model_provider["endpoint"] = endpoint
|
||||
model_provider["port"] = port
|
||||
model_provider["protocol"] = protocol
|
||||
cluster_name = (
|
||||
provider + "_" + endpoint
|
||||
) # make name unique by appending endpoint
|
||||
model_provider["cluster_name"] = cluster_name
|
||||
# Only add if cluster_name is not already present to avoid duplicates
|
||||
if cluster_name not in llms_with_endpoint_cluster_names:
|
||||
llms_with_endpoint.append(model_provider)
|
||||
llms_with_endpoint_cluster_names.add(cluster_name)
|
||||
|
||||
if len(model_usage_name_keys) > 0:
|
||||
routing_model_provider = config_yaml.get("routing", {}).get(
|
||||
"model_provider", None
|
||||
)
|
||||
if (
|
||||
routing_model_provider
|
||||
and routing_model_provider not in model_provider_name_set
|
||||
):
|
||||
raise Exception(
|
||||
f"Routing model_provider {routing_model_provider} is not defined in model_providers"
|
||||
)
|
||||
if (
|
||||
routing_model_provider is None
|
||||
and "arch-router" not in model_provider_name_set
|
||||
):
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "arch-router",
|
||||
"provider_interface": "arch",
|
||||
"model": config_yaml.get("routing", {}).get("model", "Arch-Router"),
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Always add arch-function model provider if not already defined
|
||||
if "arch-function" not in model_provider_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "arch-function",
|
||||
"provider_interface": "arch",
|
||||
"model": "Arch-Function",
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
|
||||
if "plano-orchestrator" not in model_provider_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "plano-orchestrator",
|
||||
"provider_interface": "arch",
|
||||
"model": "Plano-Orchestrator",
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
|
||||
config_yaml["model_providers"] = deepcopy(updated_model_providers)
|
||||
|
||||
listeners_with_provider = 0
|
||||
for listener in listeners:
|
||||
print("Processing listener: ", listener)
|
||||
model_providers = listener.get("model_providers", None)
|
||||
if model_providers is not None:
|
||||
listeners_with_provider += 1
|
||||
if listeners_with_provider > 1:
|
||||
raise Exception(
|
||||
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
||||
)
|
||||
|
||||
# Validate model aliases if present
|
||||
if "model_aliases" in config_yaml:
|
||||
model_aliases = config_yaml["model_aliases"]
|
||||
for alias_name, alias_config in model_aliases.items():
|
||||
target = alias_config.get("target")
|
||||
if target not in model_name_keys:
|
||||
raise Exception(
|
||||
f"Model alias 2 - '{alias_name}' targets '{target}' which is not defined as a model. Available models: {', '.join(sorted(model_name_keys))}"
|
||||
)
|
||||
|
||||
arch_config_string = yaml.dump(config_yaml)
|
||||
arch_llm_config_string = yaml.dump(config_yaml)
|
||||
|
||||
use_agent_orchestrator = config_yaml.get("overrides", {}).get(
|
||||
"use_agent_orchestrator", False
|
||||
agent_orchestrator = resolve_agent_orchestrator(
|
||||
config, config.get("endpoints", {})
|
||||
)
|
||||
|
||||
agent_orchestrator = None
|
||||
if use_agent_orchestrator:
|
||||
print("Using agent orchestrator")
|
||||
|
||||
if len(endpoints) == 0:
|
||||
raise Exception(
|
||||
"Please provide agent orchestrator in the endpoints section in your arch_config.yaml file"
|
||||
)
|
||||
elif len(endpoints) > 1:
|
||||
raise Exception(
|
||||
"Please provide single agent orchestrator in the endpoints section in your arch_config.yaml file"
|
||||
)
|
||||
else:
|
||||
agent_orchestrator = list(endpoints.keys())[0]
|
||||
|
||||
print("agent_orchestrator: ", agent_orchestrator)
|
||||
|
||||
data = {
|
||||
"prompt_gateway_listener": prompt_gateway,
|
||||
"llm_gateway_listener": llm_gateway,
|
||||
"arch_config": arch_config_string,
|
||||
"arch_llm_config": arch_llm_config_string,
|
||||
"arch_clusters": inferred_clusters,
|
||||
"arch_model_providers": updated_model_providers,
|
||||
"arch_tracing": arch_tracing,
|
||||
"local_llms": llms_with_endpoint,
|
||||
"agent_orchestrator": agent_orchestrator,
|
||||
"listeners": listeners,
|
||||
}
|
||||
data = build_template_data(
|
||||
prompt_gateway,
|
||||
llm_gateway,
|
||||
config,
|
||||
clusters,
|
||||
updated_providers,
|
||||
tracing,
|
||||
llms_with_endpoint,
|
||||
agent_orchestrator,
|
||||
listeners,
|
||||
)
|
||||
|
||||
# --- Render and write ---
|
||||
rendered = template.render(data)
|
||||
print(ENVOY_CONFIG_FILE_RENDERED)
|
||||
print(rendered)
|
||||
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
|
||||
file.write(rendered)
|
||||
log.info("Writing Envoy config to %s", envoy_rendered_path)
|
||||
|
||||
with open(ARCH_CONFIG_FILE_RENDERED, "w") as file:
|
||||
file.write(arch_config_string)
|
||||
with open(envoy_rendered_path, "w") as f:
|
||||
f.write(rendered)
|
||||
|
||||
|
||||
def validate_prompt_config(arch_config_file, arch_config_schema_file):
    """Validate the YAML config file against the YAML schema file.

    Both files are read in full (config first, then schema) before either is
    parsed with ``yaml.safe_load``. The parsed config is then checked with
    ``jsonschema.validate``. On any failure during validation an error line
    is printed and the same exception is re-raised to the caller.
    """
    raw_documents = []
    for path in (arch_config_file, arch_config_schema_file):
        with open(path, "r") as handle:
            raw_documents.append(handle.read())

    # Parse in the same order the files were read: config, then schema.
    config_doc, schema_doc = (yaml.safe_load(text) for text in raw_documents)

    try:
        validate(config_doc, schema_doc)
    except Exception as e:
        print(
            f"Error validating arch_config file: {arch_config_file}, schema file: {arch_config_schema_file}, error: {e}"
        )
        raise e
|
||||
config_string = yaml.dump(config)
|
||||
with open(rendered_config_path, "w") as f:
|
||||
f.write(config_string)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
validate_and_render_schema()
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
try:
|
||||
validate_and_render_schema()
|
||||
except ConfigValidationError as e:
|
||||
log.error(str(e))
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
log.error("Unexpected error: %s", e)
|
||||
exit(1)
|
||||
|
|
|
|||
87
cli/planoai/config_providers.py
Normal file
87
cli/planoai/config_providers.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Model provider constants, custom exception, and URL parsing utility."""
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigValidationError(Exception):
    """Error type raised to signal that config validation failed."""
|
||||
|
||||
|
||||
# --- Provider Constants ---

# Known providers whose model entries are grouped as "with base_url".
SUPPORTED_PROVIDERS_WITH_BASE_URL = [
    "azure_openai",
    "ollama",
    "qwen",
    "amazon_bedrock",
    "arch",
]

# Known providers grouped as "without base_url".
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
    "deepseek",
    "groq",
    "mistral",
    "openai",
    "gemini",
    "anthropic",
    "together_ai",
    "xai",
    "moonshotai",
    "zhipu",
]

# Every known provider: the "without base_url" group first, then "with".
SUPPORTED_PROVIDERS = [
    *SUPPORTED_PROVIDERS_WITHOUT_BASE_URL,
    *SUPPORTED_PROVIDERS_WITH_BASE_URL,
]

# Built-in provider entries keyed by provider name. Each entry uses the
# "arch" provider interface and is flagged as internal.
INTERNAL_PROVIDERS = {
    provider_name: {
        "name": provider_name,
        "provider_interface": "arch",
        "model": model_name,
        "internal": True,
    }
    for provider_name, model_name in (
        ("arch-function", "Arch-Function"),
        ("plano-orchestrator", "Plano-Orchestrator"),
    )
}
|
||||
|
||||
|
||||
def parse_url_endpoint(url):
    """Parse a URL into endpoint, port, protocol, and optional path_prefix.

    Replaces the old get_endpoint_and_port() and inline urlparse logic.

    Args:
        url: URL string; the scheme must be ``http`` or ``https`` and a
            hostname must be present.

    Returns:
        dict with keys ``endpoint`` (hostname), ``port`` (explicit port, or
        80 for http / 443 for https), ``protocol`` (scheme) and, only when
        the URL path is non-empty and not just ``/``, ``path_prefix``.

    Raises:
        ConfigValidationError: if the URL cannot be parsed, has a scheme
            other than http/https, lacks a hostname, or carries a port that
            is not an integer in the valid range.
    """
    # urlparse() itself raises ValueError for some malformed inputs (for
    # example an unbalanced IPv6 bracket); report those as config errors.
    try:
        result = urlparse(url)
    except ValueError as e:
        raise ConfigValidationError(f"Invalid URL '{url}': {e}") from e

    if not result.scheme or result.scheme not in ("http", "https"):
        raise ConfigValidationError(
            f"Invalid URL '{url}': scheme must be http or https"
        )
    if not result.hostname:
        raise ConfigValidationError(f"Invalid URL '{url}': hostname is required")

    # The .port property validates lazily and raises ValueError for a
    # non-numeric or out-of-range port; convert that to the documented type.
    try:
        port = result.port
    except ValueError as e:
        raise ConfigValidationError(f"Invalid URL '{url}': {e}") from e
    if port is None:
        port = 80 if result.scheme == "http" else 443

    parsed = {
        "endpoint": result.hostname,
        "port": port,
        "protocol": result.scheme,
    }

    if result.path and result.path != "/":
        parsed["path_prefix"] = result.path

    return parsed
|
||||
486
cli/planoai/config_validator.py
Normal file
486
cli/planoai/config_validator.py
Normal file
|
|
@ -0,0 +1,486 @@
|
|||
"""Pure validation and transformation functions for Plano config.
|
||||
|
||||
Every function in this module takes data in, returns data out, and raises
|
||||
ConfigValidationError on failure. No file I/O, no print(), no exit().
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import yaml
|
||||
from jsonschema import validate as jsonschema_validate
|
||||
|
||||
from planoai.config_providers import (
|
||||
INTERNAL_PROVIDERS,
|
||||
SUPPORTED_PROVIDERS,
|
||||
SUPPORTED_PROVIDERS_WITH_BASE_URL,
|
||||
ConfigValidationError,
|
||||
parse_url_endpoint,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_schema(config, schema):
    """Check ``config`` against the JSON schema ``schema``.

    Any exception raised by the underlying jsonschema validation call is
    wrapped in a ConfigValidationError whose message starts with
    "Schema validation failed:"; the original exception is kept as the cause.
    """
    try:
        jsonschema_validate(config, schema)
    except Exception as exc:
        message = f"Schema validation failed: {exc}"
        raise ConfigValidationError(message) from exc
|
||||
|
||||
|
||||
def migrate_legacy_providers(config):
    """Rename the deprecated ``llm_providers`` key to ``model_providers``.

    Works on a deep copy, so the caller's dict is never modified. A config
    without ``llm_providers`` is returned as an unchanged copy.

    Raises ConfigValidationError when both keys are present.
    """
    migrated = deepcopy(config)

    if "llm_providers" not in migrated:
        return migrated

    if "model_providers" in migrated:
        raise ConfigValidationError(
            "Please provide either llm_providers or model_providers, not both. "
            "llm_providers is deprecated, please use model_providers instead"
        )

    legacy_providers = migrated.pop("llm_providers")
    migrated["model_providers"] = legacy_providers
    return migrated
|
||||
|
||||
|
||||
def validate_agents(agents, filters):
    """Check agent and filter ids for uniqueness and derive clusters from URLs.

    Agents and filters share one id namespace. For each entry that has both
    an ``id`` and a ``url`` whose parsed form has a scheme and a hostname, an
    endpoint record (``endpoint``, ``port``, ``protocol``) is stored under the
    entry's id. A missing port defaults to 80 for http and 443 otherwise.
    Entries without a usable URL are skipped silently.

    Returns the dict of inferred endpoint records keyed by id.
    Raises ConfigValidationError when an id is seen twice.
    """
    known_ids = set()
    clusters_by_id = {}

    for entry in agents + filters:
        entry_id = entry.get("id")
        if entry_id in known_ids:
            raise ConfigValidationError(
                f"Duplicate agent id {entry_id}, please provide unique id for each agent"
            )
        known_ids.add(entry_id)

        entry_url = entry.get("url")
        if not (entry_id and entry_url):
            continue

        parts = urlparse(entry_url)
        if not (parts.scheme and parts.hostname):
            continue

        resolved_port = parts.port
        if resolved_port is None:
            resolved_port = 80 if parts.scheme == "http" else 443

        clusters_by_id[entry_id] = {
            "endpoint": parts.hostname,
            "port": resolved_port,
            "protocol": parts.scheme,
        }

    return clusters_by_id
|
||||
|
||||
|
||||
def build_clusters(endpoints, agent_inferred):
    """Merge explicit endpoints with agent-inferred clusters.

    Explicit entries from ``endpoints`` override agent-inferred entries of
    the same name. Each explicit entry is shallow-copied, so the caller's
    dicts are not modified. When an explicit entry has no ``port``, the port
    is taken from a ``host:port`` style ``endpoint`` value if one is given,
    and otherwise defaults to 80 for ``http`` (the default protocol) or 443
    for any other protocol.

    Args:
        endpoints: Mapping of cluster name to endpoint details; ``None`` is
            treated as empty.
        agent_inferred: Mapping of cluster name to details inferred from
            agent URLs; ``None`` is treated as empty.

    Returns:
        The final cluster dict.

    Raises:
        ConfigValidationError: if an entry without ``port`` also lacks
            ``endpoint``, or if the port part of ``host:port`` is not an
            integer.
    """
    clusters = dict(agent_inferred or {})

    for name, endpoint_details in (endpoints or {}).items():
        cluster = dict(endpoint_details)
        clusters[name] = cluster

        # Resolve port only for manually defined endpoints that lack one.
        if "port" in cluster:
            continue

        if "endpoint" not in cluster:
            raise ConfigValidationError(
                f"Endpoint '{name}' is missing the required 'endpoint' field"
            )

        endpoint = cluster["endpoint"]
        protocol = cluster.get("protocol", "http")
        if ":" in endpoint:
            parts = endpoint.split(":")
            # Surface a bad port as a validation error instead of letting the
            # raw ValueError from int() escape.
            try:
                port = int(parts[1])
            except ValueError as e:
                raise ConfigValidationError(
                    f"Invalid port '{parts[1]}' in endpoint '{endpoint}' for '{name}'"
                ) from e
            cluster["endpoint"] = parts[0]
            cluster["port"] = port
        else:
            cluster["port"] = 80 if protocol == "http" else 443

    return clusters
|
||||
|
||||
|
||||
def validate_prompt_targets(config, clusters):
    """Ensure every named prompt-target endpoint exists in ``clusters``.

    Prompt targets whose ``endpoint`` has no (or an empty) ``name`` are
    ignored. Raises ConfigValidationError for the first target that names an
    endpoint missing from ``clusters``.
    """
    for target in config.get("prompt_targets", []):
        endpoint_name = target.get("endpoint", {}).get("name", None)
        if endpoint_name and endpoint_name not in clusters:
            raise ConfigValidationError(
                f"Unknown endpoint {endpoint_name}, please add it in endpoints section "
                "in your arch_config.yaml file"
            )
|
||||
|
||||
|
||||
def validate_tracing(tracing_config, default_endpoint):
    """Resolve and validate the tracing section of the config.

    The gRPC endpoint is taken from the config's
    ``opentracing_grpc_endpoint`` key when present, otherwise from the
    ``OTEL_TRACING_GRPC_ENDPOINT`` environment variable, otherwise from
    ``default_endpoint``. A value containing ``$`` has environment variable
    references expanded. The input dict is deep-copied, not modified.

    Args:
        tracing_config: The tracing dict from the config.
        default_endpoint: Endpoint used when neither the config nor the
            environment supplies one.

    Returns:
        A copy of ``tracing_config`` with ``opentracing_grpc_endpoint`` set
        to the resolved value.

    Raises:
        ConfigValidationError: If the resolved endpoint is non-empty and its
            scheme is not ``http`` or it has a path other than ``/``.
    """
    resolved = deepcopy(tracing_config)

    # Resolution order: config yaml > OTEL_TRACING_GRPC_ENDPOINT env var > default
    fallback = os.environ.get("OTEL_TRACING_GRPC_ENDPOINT", default_endpoint)
    grpc_endpoint = resolved.get("opentracing_grpc_endpoint", fallback)

    # Expand references such as $VAR or ${VAR}
    if grpc_endpoint and "$" in grpc_endpoint:
        grpc_endpoint = os.path.expandvars(grpc_endpoint)
        log.info("Resolved opentracing_grpc_endpoint to %s", grpc_endpoint)

    resolved["opentracing_grpc_endpoint"] = grpc_endpoint

    # An empty/None endpoint is stored as-is and not validated.
    if not grpc_endpoint:
        return resolved

    parsed = urlparse(grpc_endpoint)
    if parsed.scheme != "http":
        raise ConfigValidationError(
            f"Invalid opentracing_grpc_endpoint {grpc_endpoint}, scheme must be http"
        )
    if parsed.path not in ("", "/"):
        raise ConfigValidationError(
            f"Invalid opentracing_grpc_endpoint {grpc_endpoint}, path must be empty"
        )

    return resolved
|
||||
|
||||
|
||||
def process_model_providers(listeners, routing_config):
    """Validate and normalize every model provider declared on the listeners.

    Each provider entry is handed to ``_validate_and_process_single_provider``
    together with the shared bookkeeping collections, after which
    ``_inject_internal_providers`` is given the chance to append internal
    providers.

    Args:
        listeners: List of listener dicts from config.
        routing_config: The 'routing' section from config (may be empty dict).

    Returns:
        Tuple of (updated_model_providers, llms_with_endpoint, model_name_keys).

    Raises:
        ConfigValidationError on any validation failure.
    """
    # Output collections returned to the caller.
    normalized_providers = []
    endpoint_providers = []
    known_model_names = set()

    # Bookkeeping used only to detect duplicates across providers.
    seen_provider_names = set()
    seen_preference_names = set()
    seen_cluster_names = set()

    for listener in listeners:
        provider_entries = listener.get("model_providers")
        if not provider_entries:
            continue

        for provider_entry in provider_entries:
            _validate_and_process_single_provider(
                provider_entry,
                known_model_names,
                seen_provider_names,
                seen_preference_names,
                normalized_providers,
                endpoint_providers,
                seen_cluster_names,
            )

    # Inject internal providers
    _inject_internal_providers(
        normalized_providers,
        seen_provider_names,
        seen_preference_names,
        routing_config,
    )

    return normalized_providers, endpoint_providers, known_model_names
|
||||
|
||||
|
||||
def _validate_and_process_single_provider(
    model_provider,
    model_name_keys,
    model_provider_name_set,
    model_usage_name_keys,
    updated_model_providers,
    llms_with_endpoint,
    llms_with_endpoint_cluster_names,
):
    """Validate and normalize a single model_provider entry.

    The entry is normalized in place (``name`` is defaulted to the full model
    name, ``model`` is rewritten to the model id and ``provider_interface``
    to the resolved provider) and appended to ``updated_model_providers``.
    The tracking collections are updated as a side effect so that later
    entries can be checked for duplicates.

    Args:
        model_provider: Provider dict from a listener; mutated in place.
        model_name_keys: Set of full model names and model ids seen so far
            (wildcard entries are not recorded).
        model_provider_name_set: Set of provider names seen so far.
        model_usage_name_keys: Set of routing preference names seen so far.
        updated_model_providers: Output list of normalized providers.
        llms_with_endpoint: Forwarded to ``_process_base_url``.
        llms_with_endpoint_cluster_names: Forwarded to ``_process_base_url``.

    Raises:
        ConfigValidationError: On a duplicate provider name, model name,
            model id or routing preference name; on a missing or malformed
            model name; on a wildcard model marked default or carrying
            routing preferences; or on missing/conflicting provider,
            ``base_url`` or ``provider_interface`` settings.
    """
    # Check duplicate provider name
    if model_provider.get("name") in model_provider_name_set:
        raise ConfigValidationError(
            f"Duplicate model_provider name {model_provider.get('name')}, "
            "please provide unique name for each model_provider"
        )

    model_name = model_provider.get("model")

    # A missing or non-string model would otherwise fail on .split() with an
    # AttributeError; report it as the same validation error as a malformed
    # name instead.
    if not isinstance(model_name, str):
        raise ConfigValidationError(
            f"Invalid model name {model_name}. Please provide model name in the "
            "format <provider>/<model_id> or <provider>/* for wildcards."
        )

    # Parse model name into provider/model_id
    model_name_tokens = model_name.split("/")
    if len(model_name_tokens) < 2:
        raise ConfigValidationError(
            f"Invalid model name {model_name}. Please provide model name in the "
            "format <provider>/<model_id> or <provider>/* for wildcards."
        )

    provider = model_name_tokens[0].strip()
    # The model id keeps any further slashes after the provider prefix.
    model_id = "/".join(model_name_tokens[1:])
    is_wildcard = model_name_tokens[-1].strip() == "*"

    # Check duplicate model name (non-wildcard only)
    if model_name in model_name_keys and not is_wildcard:
        raise ConfigValidationError(
            f"Duplicate model name {model_name}, please provide unique model "
            "name for each model_provider"
        )

    if not is_wildcard:
        model_name_keys.add(model_name)

    # Auto-name if not provided
    if model_provider.get("name") is None:
        model_provider["name"] = model_name

    model_provider_name_set.add(model_provider.get("name"))

    # Validate wildcard constraints
    if is_wildcard:
        if model_provider.get("default", False):
            raise ConfigValidationError(
                f"Model {model_name} is configured as default but uses wildcard (*). "
                "Default models cannot be wildcards."
            )
        if model_provider.get("routing_preferences"):
            raise ConfigValidationError(
                f"Model {model_name} has routing_preferences but uses wildcard (*). "
                "Models with routing preferences cannot be wildcards."
            )

    # Validate provider requires base_url
    if provider in SUPPORTED_PROVIDERS_WITH_BASE_URL and not model_provider.get(
        "base_url"
    ):
        raise ConfigValidationError(
            f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
        )

    # Resolve provider interface
    if provider not in SUPPORTED_PROVIDERS:
        if not model_provider.get("base_url") or not model_provider.get(
            "provider_interface"
        ):
            raise ConfigValidationError(
                f"Must provide base_url and provider_interface for unsupported "
                f"provider {provider} for {'wildcard ' if is_wildcard else ''}model "
                f"{model_name}. Supported providers are: "
                f"{', '.join(SUPPORTED_PROVIDERS)}"
            )
        # Unsupported providers are handled through the declared interface.
        provider = model_provider.get("provider_interface")
    elif model_provider.get("provider_interface") is not None:
        raise ConfigValidationError(
            f"Please provide provider interface as part of model name {model_name} "
            "using the format <provider>/<model_id>. For example, use "
            "'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
        )

    # Check duplicate model_id (non-wildcard only)
    if not is_wildcard:
        if model_id in model_name_keys:
            raise ConfigValidationError(
                f"Duplicate model_id {model_id}, please provide unique model_id "
                "for each model_provider"
            )
        model_name_keys.add(model_id)

    # Validate routing preferences
    for routing_preference in model_provider.get("routing_preferences", []):
        pref_name = routing_preference.get("name")
        if pref_name in model_usage_name_keys:
            raise ConfigValidationError(
                f'Duplicate routing preference name "{pref_name}", please provide '
                "unique name for each routing preference"
            )
        model_usage_name_keys.add(pref_name)

    # Warn if both passthrough_auth and access_key are configured
    if model_provider.get("passthrough_auth") and model_provider.get("access_key"):
        log.warning(
            "Model provider '%s' has both 'passthrough_auth: true' and 'access_key' "
            "configured. The access_key will be ignored and the client's Authorization "
            "header will be forwarded instead.",
            model_provider.get("name"),
        )

    # Normalize provider fields. The provider name was already recorded in
    # model_provider_name_set above and has not changed since.
    model_provider["model"] = model_id
    model_provider["provider_interface"] = provider

    # NOTE(review): provider_interface was assigned just above, so this check
    # appears to fire whenever a 'provider' key is set, which would make the
    # fallback branch below unreachable -- confirm the intended behaviour.
    if model_provider.get("provider") and model_provider.get("provider_interface"):
        raise ConfigValidationError(
            "Please provide either provider or provider_interface, not both"
        )
    if model_provider.get("provider"):
        provider = model_provider["provider"]
        model_provider["provider_interface"] = provider
        del model_provider["provider"]

    updated_model_providers.append(model_provider)

    # Process base_url into cluster endpoint info
    if model_provider.get("base_url"):
        _process_base_url(
            model_provider,
            provider,
            llms_with_endpoint,
            llms_with_endpoint_cluster_names,
        )
|
||||
|
||||
|
||||
def _process_base_url(
    model_provider, provider, llms_with_endpoint, llms_with_endpoint_cluster_names
):
    """Derive cluster endpoint fields for a provider from its ``base_url``.

    The URL is parsed with ``parse_url_endpoint`` and the resulting
    ``endpoint``, ``port`` and ``protocol`` (plus ``base_url_path_prefix``
    when a path prefix is present) are written onto ``model_provider``. The
    provider is also given a ``cluster_name`` of ``<provider>_<endpoint>``
    and is recorded in ``llms_with_endpoint`` the first time that cluster
    name is seen.
    """
    url_parts = parse_url_endpoint(model_provider["base_url"])

    if url_parts.get("path_prefix"):
        model_provider["base_url_path_prefix"] = url_parts["path_prefix"]

    for field in ("endpoint", "port", "protocol"):
        model_provider[field] = url_parts[field]

    cluster_name = provider + "_" + url_parts["endpoint"]
    model_provider["cluster_name"] = cluster_name

    # Only the first provider seen for a given cluster is recorded.
    if cluster_name in llms_with_endpoint_cluster_names:
        return
    llms_with_endpoint.append(model_provider)
    llms_with_endpoint_cluster_names.add(cluster_name)
|
||||
|
||||
|
||||
def _inject_internal_providers(
    updated_model_providers,
    model_provider_name_set,
    model_usage_name_keys,
    routing_config,
):
    """Append internal providers that the user config did not define.

    When at least one routing preference exists, the routing section's
    ``model_provider`` (if set) must name a defined provider; when it is not
    set and no provider is called ``arch-router``, a default ``arch-router``
    entry is appended. Afterwards a copy of every ``INTERNAL_PROVIDERS``
    entry whose name is not already defined is appended.

    Raises:
        ConfigValidationError: If the routing ``model_provider`` is set but
            not defined in model_providers.
    """
    # A router is only needed when routing preferences exist.
    if model_usage_name_keys:
        router_name = routing_config.get("model_provider", None)

        if router_name and router_name not in model_provider_name_set:
            raise ConfigValidationError(
                f"Routing model_provider {router_name} is not defined "
                "in model_providers"
            )

        needs_default_router = (
            router_name is None and "arch-router" not in model_provider_name_set
        )
        if needs_default_router:
            default_router = {
                "name": "arch-router",
                "provider_interface": "arch",
                "model": routing_config.get("model", "Arch-Router"),
                "internal": True,
            }
            updated_model_providers.append(default_router)

    # Copy each definition so the shared constant is never handed out directly.
    updated_model_providers.extend(
        dict(definition)
        for internal_name, definition in INTERNAL_PROVIDERS.items()
        if internal_name not in model_provider_name_set
    )
|
||||
|
||||
|
||||
def validate_listeners(listeners):
    """Ensure no more than one listener declares ``model_providers``.

    A listener counts when its ``model_providers`` value is anything other
    than ``None`` (an empty list still counts).

    Raises:
        ConfigValidationError: If more than one listener has model_providers.
    """
    with_providers = [
        entry for entry in listeners if entry.get("model_providers") is not None
    ]
    if len(with_providers) <= 1:
        return

    raise ConfigValidationError(
        "Please provide model_providers either under listeners or at root level, "
        "not both. Currently we don't support multiple listeners with model_providers"
    )
|
||||
|
||||
|
||||
def validate_model_aliases(aliases, model_name_keys):
    """Ensure every model alias targets a model that is defined.

    Args:
        aliases: Mapping of alias name to alias config (a dict with a
            ``target`` key).
        model_name_keys: Collection of defined model names.

    Raises:
        ConfigValidationError: On the first alias whose target is not in
            ``model_name_keys``; the message lists the available models.
    """
    for alias_name in aliases:
        target = aliases[alias_name].get("target")
        if target in model_name_keys:
            continue

        available = ", ".join(sorted(model_name_keys))
        raise ConfigValidationError(
            f"Model alias '{alias_name}' targets '{target}' which is not "
            f"defined as a model. Available models: {available}"
        )
|
||||
|
||||
|
||||
def resolve_agent_orchestrator(config, endpoints):
    """Pick the agent orchestrator endpoint when the override enables it.

    Args:
        config: Config dict; ``overrides.use_agent_orchestrator`` is read.
        endpoints: Mapping of endpoint name to endpoint details.

    Returns:
        The name of the single configured endpoint, or None when the
        orchestrator override is not enabled.

    Raises:
        ConfigValidationError: If the override is enabled and ``endpoints``
            does not contain exactly one entry.
    """
    overrides = config.get("overrides", {})
    if not overrides.get("use_agent_orchestrator", False):
        return None

    endpoint_count = len(endpoints)
    if endpoint_count == 0:
        raise ConfigValidationError(
            "Please provide agent orchestrator in the endpoints section "
            "in your arch_config.yaml file"
        )
    if endpoint_count > 1:
        raise ConfigValidationError(
            "Please provide single agent orchestrator in the endpoints section "
            "in your arch_config.yaml file"
        )

    # Exactly one endpoint remains; it is the orchestrator.
    (orchestrator_name,) = endpoints.keys()
    return orchestrator_name
|
||||
|
||||
|
||||
def build_template_data(
    prompt_gateway,
    llm_gateway,
    config_yaml,
    clusters,
    model_providers,
    tracing,
    llms_with_endpoint,
    agent_orchestrator,
    listeners,
):
    """Build the context dict passed to the Jinja2 template.

    ``config_yaml`` is serialized once with ``yaml.dump`` and the same string
    is exposed under both ``arch_config`` and ``arch_llm_config``; the
    duplication is intentional, for backward compatibility with the Envoy
    template.

    Returns:
        Dict mapping template variable names to their values.
    """
    serialized_config = yaml.dump(config_yaml)
    return dict(
        prompt_gateway_listener=prompt_gateway,
        llm_gateway_listener=llm_gateway,
        arch_config=serialized_config,
        arch_llm_config=serialized_config,
        arch_clusters=clusters,
        arch_model_providers=model_providers,
        arch_tracing=tracing,
        local_llms=llms_with_endpoint,
        agent_orchestrator=agent_orchestrator,
        listeners=listeners,
    )
|
||||
|
|
@ -115,6 +115,28 @@ def stream_gateway_logs(follow, service="plano"):
|
|||
log.info(f"Failed to stream logs: {str(e)}")
|
||||
|
||||
|
||||
def stream_access_logs(follow):
    """Tail the access logs inside the running Plano container.

    Runs ``tail`` over ``/var/log/access_*.log`` through ``docker exec`` in
    the container named by ``PLANO_DOCKER_NAME``, wiring the command's output
    to this process's stdout/stderr.

    Args:
        follow: When truthy, ``-f`` is passed to ``tail``.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero
            (``check=True``).
    """
    tail_flag = ""
    if follow:
        tail_flag = "-f"

    # The glob must be expanded inside the container, hence `sh -c`.
    shell_snippet = f"tail {tail_flag} /var/log/access_*.log"

    subprocess.run(
        ["docker", "exec", PLANO_DOCKER_NAME, "sh", "-c", shell_snippet],
        check=True,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
|
||||
|
||||
|
||||
def docker_validate_plano_schema(arch_config_file):
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ PLANO_COLOR = "#969FF4"
|
|||
from planoai.docker_cli import (
|
||||
docker_validate_plano_schema,
|
||||
stream_gateway_logs,
|
||||
stream_access_logs,
|
||||
docker_container_status,
|
||||
)
|
||||
from planoai.utils import (
|
||||
|
|
@ -17,7 +18,6 @@ from planoai.utils import (
|
|||
get_llm_provider_access_keys,
|
||||
load_env_file_to_dict,
|
||||
set_log_level,
|
||||
stream_access_logs,
|
||||
find_config_file,
|
||||
find_repo_root,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,6 @@
|
|||
import glob
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import yaml
|
||||
import logging
|
||||
from planoai.consts import PLANO_DOCKER_NAME
|
||||
|
||||
|
||||
# Standard env var for log level across all Plano components
|
||||
|
|
@ -162,20 +158,15 @@ def convert_legacy_listeners(
|
|||
|
||||
|
||||
def get_llm_provider_access_keys(arch_config_file):
|
||||
from planoai.config_validator import migrate_legacy_providers
|
||||
|
||||
with open(arch_config_file, "r") as file:
|
||||
arch_config = file.read()
|
||||
arch_config_yaml = yaml.safe_load(arch_config)
|
||||
|
||||
access_key_list = []
|
||||
|
||||
# Convert legacy llm_providers to model_providers
|
||||
if "llm_providers" in arch_config_yaml:
|
||||
if "model_providers" in arch_config_yaml:
|
||||
raise Exception(
|
||||
"Please provide either llm_providers or model_providers, not both. llm_providers is deprecated, please use model_providers instead"
|
||||
)
|
||||
arch_config_yaml["model_providers"] = arch_config_yaml["llm_providers"]
|
||||
del arch_config_yaml["llm_providers"]
|
||||
arch_config_yaml = migrate_legacy_providers(arch_config_yaml)
|
||||
|
||||
listeners, _, _ = convert_legacy_listeners(
|
||||
arch_config_yaml.get("listeners"), arch_config_yaml.get("model_providers")
|
||||
|
|
@ -258,25 +249,3 @@ def find_config_file(path=".", file=None):
|
|||
return arch_config_file
|
||||
|
||||
|
||||
def stream_access_logs(follow):
    """Stream the gateway access logs from the running container.

    Executes ``tail`` on ``/var/log/access_*.log`` via ``docker exec`` in the
    container named by ``PLANO_DOCKER_NAME``; output goes to this process's
    stdout/stderr. ``-f`` is added to ``tail`` when ``follow`` is truthy.
    """
    tail_options = "-f" if follow else ""
    docker_exec = ["docker", "exec", PLANO_DOCKER_NAME]
    # Run through a shell so the log-file glob is expanded in the container.
    tail_in_shell = ["sh", "-c", f"tail {tail_options} /var/log/access_*.log"]

    subprocess.run(
        docker_exec + tail_in_shell,
        check=True,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -13,6 +13,8 @@ services:
|
|||
- ./envoy.template.yaml:/app/envoy.template.yaml
|
||||
- ./arch_config_schema.yaml:/app/arch_config_schema.yaml
|
||||
- ../cli/planoai/config_generator.py:/app/planoai/config_generator.py
|
||||
- ../cli/planoai/config_validator.py:/app/planoai/config_validator.py
|
||||
- ../cli/planoai/config_providers.py:/app/planoai/config_providers.py
|
||||
- ../crates/target/wasm32-wasip1/release/llm_gateway.wasm:/etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
|
||||
- ../crates/target/wasm32-wasip1/release/prompt_gateway.wasm:/etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
|
||||
- ~/archgw_logs:/var/log/
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue