Use intent model from archfc to pick prompt gateway (#328)

This commit is contained in:
Shuguang Chen 2024-12-20 13:25:01 -08:00 committed by GitHub
parent 67b8fd635e
commit ba7279becb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
151 changed files with 8642 additions and 10932 deletions

View file

@ -1,3 +1,4 @@
import json
import os
from jinja2 import Environment, FileSystemLoader
import yaml
@ -47,32 +48,27 @@ def validate_and_render_schema():
config_schema_yaml = yaml.safe_load(arch_config_schema)
inferred_clusters = {}
endpoints = config_yaml.get("endpoints", {})
# override the inferred clusters with the ones defined in the config
for name, endpoint_details in endpoints.items():
inferred_clusters[name] = endpoint_details
endpoint = inferred_clusters[name]["endpoint"]
if len(endpoint.split(":")) > 1:
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
if "prompt_targets" in config_yaml:
for prompt_target in config_yaml["prompt_targets"]:
name = prompt_target.get("endpoint", {}).get("name", None)
if not name:
continue
if name not in inferred_clusters:
inferred_clusters[name] = {
"name": name,
"port": 80, # default port
}
endpoints = config_yaml.get("endpoints", {})
# override the inferred clusters with the ones defined in the config
for name, endpoint_details in endpoints.items():
if name in inferred_clusters:
print("updating cluster", endpoint_details)
inferred_clusters[name].update(endpoint_details)
endpoint = inferred_clusters[name]["endpoint"]
if len(endpoint.split(":")) > 1:
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
else:
inferred_clusters[name] = endpoint_details
print("updated clusters", inferred_clusters)
raise Exception(
f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file"
)
arch_llm_providers = config_yaml["llm_providers"]
arch_tracing = config_yaml.get("tracing", {})
@ -90,6 +86,7 @@ def validate_and_render_schema():
rendered = template.render(data)
print(ENVOY_CONFIG_FILE_RENDERED)
print(rendered)
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
file.write(rendered)