mirror of
https://github.com/katanemo/plano.git
synced 2026-04-26 01:06:25 +02:00
better model names (#517)
This commit is contained in:
parent
4e2355965b
commit
a7fddf30f9
55 changed files with 979 additions and 483 deletions
|
|
@ -5,16 +5,16 @@ import yaml
|
|||
from jsonschema import validate
|
||||
from urllib.parse import urlparse
|
||||
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
|
||||
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
|
||||
)
|
||||
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml")
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
|
||||
)
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
|
||||
"ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
|
||||
)
|
||||
|
||||
SUPPORTED_PROVIDERS = [
|
||||
"arch",
|
||||
"claude",
|
||||
"deepseek",
|
||||
"groq",
|
||||
"mistral",
|
||||
"openai",
|
||||
"gemini",
|
||||
]
|
||||
|
||||
|
||||
def get_endpoint_and_port(endpoint, protocol):
|
||||
|
|
@ -32,8 +32,22 @@ def get_endpoint_and_port(endpoint, protocol):
|
|||
|
||||
|
||||
def validate_and_render_schema():
|
||||
env = Environment(loader=FileSystemLoader("./"))
|
||||
template = env.get_template("envoy.template.yaml")
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
|
||||
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
|
||||
)
|
||||
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml")
|
||||
ARCH_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ARCH_CONFIG_FILE_RENDERED", "/app/arch_config_rendered.yaml"
|
||||
)
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
|
||||
)
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
|
||||
"ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
|
||||
)
|
||||
|
||||
env = Environment(loader=FileSystemLoader(os.getenv("TEMPLATE_ROOT", "./")))
|
||||
template = env.get_template(ENVOY_CONFIG_TEMPLATE_FILE)
|
||||
|
||||
try:
|
||||
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
|
||||
|
|
@ -82,6 +96,8 @@ def validate_and_render_schema():
|
|||
updated_llm_providers = []
|
||||
llm_provider_name_set = set()
|
||||
llms_with_usage = []
|
||||
model_name_keys = set()
|
||||
model_usage_name_keys = set()
|
||||
for llm_provider in config_yaml["llm_providers"]:
|
||||
if llm_provider.get("usage", None):
|
||||
llms_with_usage.append(llm_provider["name"])
|
||||
|
|
@ -89,10 +105,52 @@ def validate_and_render_schema():
|
|||
raise Exception(
|
||||
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
|
||||
)
|
||||
if llm_provider.get("name") is None:
|
||||
|
||||
model_name = llm_provider.get("model")
|
||||
if model_name in model_name_keys:
|
||||
raise Exception(
|
||||
f"llm_provider name is required, please provide name for llm_provider"
|
||||
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
|
||||
)
|
||||
model_name_keys.add(model_name)
|
||||
if llm_provider.get("name") is None:
|
||||
llm_provider["name"] = model_name
|
||||
|
||||
model_name_tokens = model_name.split("/")
|
||||
if len(model_name_tokens) < 2:
|
||||
raise Exception(
|
||||
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
|
||||
)
|
||||
provider = model_name_tokens[0]
|
||||
model_id = "/".join(model_name_tokens[1:])
|
||||
if provider not in SUPPORTED_PROVIDERS:
|
||||
if (
|
||||
llm_provider.get("base_url", None) is None
|
||||
or llm_provider.get("provider_interface", None) is None
|
||||
):
|
||||
raise Exception(
|
||||
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
|
||||
)
|
||||
provider = llm_provider.get("provider_interface", None)
|
||||
elif llm_provider.get("provider_interface", None) is not None:
|
||||
raise Exception(
|
||||
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
|
||||
)
|
||||
|
||||
if model_id in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model_id {model_id}, please provide unique model_id for each llm_provider"
|
||||
)
|
||||
model_name_keys.add(model_id)
|
||||
|
||||
for routing_preference in llm_provider.get("routing_preferences", []):
|
||||
if routing_preference.get("name") in model_usage_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference"
|
||||
)
|
||||
model_usage_name_keys.add(routing_preference.get("name"))
|
||||
|
||||
llm_provider["model"] = model_id
|
||||
llm_provider["provider_interface"] = provider
|
||||
llm_provider_name_set.add(llm_provider.get("name"))
|
||||
provider = None
|
||||
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
|
||||
|
|
@ -105,21 +163,14 @@ def validate_and_render_schema():
|
|||
del llm_provider["provider"]
|
||||
updated_llm_providers.append(llm_provider)
|
||||
|
||||
if llm_provider.get("endpoint") and llm_provider.get("base_url"):
|
||||
raise Exception("Please provide either endpoint or base_url, not both")
|
||||
|
||||
if llm_provider.get("endpoint", None):
|
||||
endpoint = llm_provider["endpoint"]
|
||||
protocol = llm_provider.get("protocol", "http")
|
||||
llm_provider["endpoint"], llm_provider["port"] = get_endpoint_and_port(
|
||||
endpoint, protocol
|
||||
)
|
||||
llms_with_endpoint.append(llm_provider)
|
||||
elif llm_provider.get("base_url", None):
|
||||
if llm_provider.get("base_url", None):
|
||||
base_url = llm_provider["base_url"]
|
||||
urlparse_result = urlparse(base_url)
|
||||
if llm_provider.get("port"):
|
||||
raise Exception("Please provider port in base_url")
|
||||
url_path = urlparse_result.path
|
||||
if url_path and url_path != "/":
|
||||
raise Exception(
|
||||
f"Please provide base_url without path, got {base_url}. Use base_url like 'http://example.com' instead of 'http://example.com/path'."
|
||||
)
|
||||
if urlparse_result.scheme == "" or urlparse_result.scheme not in [
|
||||
"http",
|
||||
"https",
|
||||
|
|
@ -140,7 +191,7 @@ def validate_and_render_schema():
|
|||
llm_provider["protocol"] = protocol
|
||||
llms_with_endpoint.append(llm_provider)
|
||||
|
||||
if len(llms_with_usage) > 0:
|
||||
if len(model_usage_name_keys) > 0:
|
||||
routing_llm_provider = config_yaml.get("routing", {}).get("llm_provider", None)
|
||||
if routing_llm_provider and routing_llm_provider not in llm_provider_name_set:
|
||||
raise Exception(
|
||||
|
|
@ -198,6 +249,7 @@ def validate_and_render_schema():
|
|||
agent_orchestrator = list(endpoints.keys())[0]
|
||||
|
||||
print("agent_orchestrator: ", agent_orchestrator)
|
||||
|
||||
data = {
|
||||
"prompt_gateway_listener": prompt_gateway_listener,
|
||||
"llm_gateway_listener": llm_gateway_listener,
|
||||
|
|
@ -216,6 +268,9 @@ def validate_and_render_schema():
|
|||
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
|
||||
file.write(rendered)
|
||||
|
||||
with open(ARCH_CONFIG_FILE_RENDERED, "w") as file:
|
||||
file.write(arch_config_string)
|
||||
|
||||
|
||||
def validate_prompt_config(arch_config_file, arch_config_schema_file):
|
||||
with open(arch_config_file, "r") as file:
|
||||
|
|
@ -231,7 +286,7 @@ def validate_prompt_config(arch_config_file, arch_config_schema_file):
|
|||
validate(config_yaml, config_schema_yaml)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error validating arch_config file: {arch_config_file}, schema file: {arch_config_schema_file}, error: {e.message}"
|
||||
f"Error validating arch_config file: {arch_config_file}, schema file: {arch_config_schema_file}, error: {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue