2024-12-20 13:25:01 -08:00
|
|
|
import json
import os
import sys
from urllib.parse import urlparse

import yaml
from jinja2 import Environment, FileSystemLoader
from jsonschema import validate
|
2024-10-03 18:21:27 -07:00
|
|
|
|
2024-10-09 11:25:07 -07:00
|
|
|
# Jinja2 template used to render the envoy config; resolved relative to the
# current working directory by the FileSystemLoader below.
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
    "ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
)
# User-supplied arch configuration that gets validated and rendered.
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml")
# Destination path for the rendered envoy configuration.
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
    "ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
)
# JSON schema (in YAML form) that ARCH_CONFIG_FILE is validated against.
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
    "ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
)
|
|
|
|
|
|
|
|
|
|
|
2025-02-03 14:51:59 -08:00
|
|
|
def get_endpoint_and_port(endpoint, protocol):
    """Split a "host[:port]" string into a (host, port) tuple.

    When the endpoint carries no explicit port, fall back to the
    protocol's default: 80 for http, 443 otherwise.
    """
    tokens = endpoint.split(":")
    if len(tokens) > 1:
        # Explicit port given; keep only the host part before the colon.
        return tokens[0], int(tokens[1])
    # No port in the endpoint string: derive it from the protocol.
    default_port = 80 if protocol == "http" else 443
    return endpoint, default_port
|
|
|
|
|
|
|
|
|
|
|
2024-10-03 18:21:27 -07:00
|
|
|
def _apply_listener_defaults(listener, default_port):
    """Fill in port/address/timeout defaults on a listener dict, in place.

    The dict lives inside config_yaml, so the rendered template sees the
    defaulted values.
    """
    if listener.get("port") is None:
        listener["port"] = default_port
    if listener.get("address") is None:
        listener["address"] = "127.0.0.1"
    if listener.get("timeout") is None:
        listener["timeout"] = "10s"


def validate_and_render_schema():
    """Validate the arch config, normalize it, and render the envoy template.

    Reads ARCH_CONFIG_FILE, validates it against ARCH_CONFIG_SCHEMA_FILE,
    normalizes endpoints / llm_providers / listeners, then renders the
    jinja2 template named by ENVOY_CONFIG_TEMPLATE_FILE and writes the
    result to ENVOY_CONFIG_FILE_RENDERED.

    Exits the process with status 1 when schema validation fails; raises
    Exception for semantic configuration errors (missing/duplicate provider
    names, unknown prompt-target endpoints, conflicting endpoint/base_url,
    agent-orchestrator misconfiguration).
    """
    env = Environment(loader=FileSystemLoader("./"))
    # Fix: honor the ENVOY_CONFIG_TEMPLATE_FILE override; the template name
    # was previously hard-coded to "envoy.template.yaml".
    template = env.get_template(ENVOY_CONFIG_TEMPLATE_FILE)

    try:
        validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
    except Exception as e:
        print(str(e))
        sys.exit(1)  # validate_prompt_config failed. Exit

    with open(ARCH_CONFIG_FILE, "r") as file:
        arch_config = file.read()

    with open(ARCH_CONFIG_SCHEMA_FILE, "r") as file:
        arch_config_schema = file.read()

    config_yaml = yaml.safe_load(arch_config)
    # Parse check only; the schema was already applied by validate_prompt_config.
    _ = yaml.safe_load(arch_config_schema)

    # Endpoints declared in the config become envoy clusters.
    inferred_clusters = {}
    endpoints = config_yaml.get("endpoints", {})

    # override the inferred clusters with the ones defined in the config
    for name, endpoint_details in endpoints.items():
        inferred_clusters[name] = endpoint_details
        endpoint = endpoint_details["endpoint"]
        protocol = endpoint_details.get("protocol", "http")
        (
            endpoint_details["endpoint"],
            endpoint_details["port"],
        ) = get_endpoint_and_port(endpoint, protocol)

    print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))

    # Every prompt_target must reference a declared endpoint.
    if "prompt_targets" in config_yaml:
        for prompt_target in config_yaml["prompt_targets"]:
            name = prompt_target.get("endpoint", {}).get("name", None)
            if not name:
                continue
            if name not in inferred_clusters:
                raise Exception(
                    f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file"
                )

    arch_tracing = config_yaml.get("tracing", {})

    llms_with_endpoint = []
    updated_llm_providers = []
    llm_provider_name_set = set()
    llms_with_usage = []
    for llm_provider in config_yaml["llm_providers"]:
        # Validate the name before anything reads it, so a missing name
        # surfaces as the intended error rather than a KeyError.
        provider_name = llm_provider.get("name")
        if provider_name is None:
            raise Exception(
                "llm_provider name is required, please provide name for llm_provider"
            )
        if provider_name in llm_provider_name_set:
            raise Exception(
                f"Duplicate llm_provider name {provider_name}, please provide unique name for each llm_provider"
            )
        llm_provider_name_set.add(provider_name)

        if llm_provider.get("usage", None):
            llms_with_usage.append(provider_name)

        # "provider" is an alias for "provider_interface"; normalize to the
        # latter and reject configs that set both.
        if llm_provider.get("provider") and llm_provider.get("provider_interface"):
            raise Exception(
                "Please provide either provider or provider_interface, not both"
            )
        if llm_provider.get("provider"):
            llm_provider["provider_interface"] = llm_provider.pop("provider")
        updated_llm_providers.append(llm_provider)

        if llm_provider.get("endpoint") and llm_provider.get("base_url"):
            raise Exception("Please provide either endpoint or base_url, not both")

        if llm_provider.get("endpoint", None):
            endpoint = llm_provider["endpoint"]
            protocol = llm_provider.get("protocol", "http")
            llm_provider["endpoint"], llm_provider["port"] = get_endpoint_and_port(
                endpoint, protocol
            )
            llms_with_endpoint.append(llm_provider)
        elif llm_provider.get("base_url", None):
            base_url = llm_provider["base_url"]
            urlparse_result = urlparse(base_url)
            if llm_provider.get("port"):
                # The port must be part of base_url, not a separate field.
                raise Exception("Please provide port in base_url")
            if urlparse_result.scheme not in ("http", "https"):
                raise Exception(
                    "Please provide a valid URL with scheme (http/https) in base_url"
                )
            protocol = urlparse_result.scheme
            port = urlparse_result.port
            if port is None:
                # No explicit port in the URL: use the scheme's default.
                port = 80 if protocol == "http" else 443
            llm_provider["endpoint"] = urlparse_result.hostname
            llm_provider["port"] = port
            llm_provider["protocol"] = protocol
            llms_with_endpoint.append(llm_provider)

    # Usage-based providers require a routing model to arbitrate between them.
    if llms_with_usage and config_yaml.get("routing", {}).get("model", None) is None:
        llms_with_usage_names = ", ".join(llms_with_usage)
        raise Exception(
            f"LLMs with usage found ({llms_with_usage_names}), please provide model in routing section in your arch_config.yaml file"
        )

    config_yaml["llm_providers"] = updated_llm_providers

    arch_config_string = yaml.dump(config_yaml)
    arch_llm_config_string = yaml.dump(config_yaml)

    listeners = config_yaml.get("listeners", {})
    prompt_gateway_listener = listeners.get("ingress_traffic", {})
    _apply_listener_defaults(prompt_gateway_listener, 10000)  # default port for prompt gateway

    llm_gateway_listener = listeners.get("egress_traffic", {})
    _apply_listener_defaults(llm_gateway_listener, 12000)  # default port for llm gateway

    use_agent_orchestrator = config_yaml.get("overrides", {}).get(
        "use_agent_orchestrator", False
    )

    agent_orchestrator = None
    if use_agent_orchestrator:
        print("Using agent orchestrator")

        if len(endpoints) == 0:
            raise Exception(
                "Please provide agent orchestrator in the endpoints section in your arch_config.yaml file"
            )
        elif len(endpoints) > 1:
            raise Exception(
                "Please provide single agent orchestrator in the endpoints section in your arch_config.yaml file"
            )
        else:
            # Exactly one endpoint declared: that endpoint is the orchestrator.
            agent_orchestrator = list(endpoints.keys())[0]

    print("agent_orchestrator: ", agent_orchestrator)

    # Template context for envoy.template.yaml.
    data = {
        "prompt_gateway_listener": prompt_gateway_listener,
        "llm_gateway_listener": llm_gateway_listener,
        "arch_config": arch_config_string,
        "arch_llm_config": arch_llm_config_string,
        "arch_clusters": inferred_clusters,
        "arch_llm_providers": config_yaml["llm_providers"],
        "arch_tracing": arch_tracing,
        "local_llms": llms_with_endpoint,
        "agent_orchestrator": agent_orchestrator,
    }

    rendered = template.render(data)
    print(ENVOY_CONFIG_FILE_RENDERED)
    print(rendered)
    with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
        file.write(rendered)
|
|
|
|
|
|
2024-10-09 11:25:07 -07:00
|
|
|
|
2024-10-03 18:21:27 -07:00
|
|
|
def validate_prompt_config(arch_config_file, arch_config_schema_file):
    """Validate an arch config file against its JSON schema.

    Args:
        arch_config_file: path to the user's arch_config YAML file.
        arch_config_schema_file: path to the schema YAML file.

    Raises:
        Whatever jsonschema.validate raises on failure (ValidationError /
        SchemaError); the error is printed with file context before being
        re-raised.
    """
    with open(arch_config_file, "r") as file:
        arch_config = file.read()

    with open(arch_config_schema_file, "r") as file:
        arch_config_schema = file.read()

    config_yaml = yaml.safe_load(arch_config)
    config_schema_yaml = yaml.safe_load(arch_config_schema)

    try:
        validate(config_yaml, config_schema_yaml)
    except Exception as e:
        # Fix: only jsonschema errors carry a .message attribute; reading it
        # unconditionally turned other exception types into AttributeError
        # and masked the real failure. Fall back to str(e).
        detail = getattr(e, "message", str(e))
        print(
            f"Error validating arch_config file: {arch_config_file}, schema file: {arch_config_schema_file}, error: {detail}"
        )
        raise
|
|
|
|
|
|
2024-10-09 11:25:07 -07:00
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: validate arch_config.yaml and render the envoy config.
    validate_and_render_schema()
|