From a91fbdbf1c2be9748e1f1ec50c88edaac0759057 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Wed, 18 Sep 2024 20:03:26 -0700 Subject: [PATCH] Add ability to define clusters in config (#61) --- .pre-commit-config.yaml | 1 + config_generator/Dockerfile | 4 +-- config_generator/config_generator.py | 31 ++++++++++++++++- config_generator/requirements.txt | 2 ++ demos/function_calling/bolt_config.yaml | 4 +++ envoyfilter/envoy.template.yaml | 45 +++++++++---------------- gateway.code-workspace | 4 +++ 7 files changed, 59 insertions(+), 32 deletions(-) create mode 100644 config_generator/requirements.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c6548745..5b34693d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,6 +3,7 @@ repos: rev: v4.6.0 hooks: - id: check-yaml + exclude: envoyfilter/envoy.template.yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: local diff --git a/config_generator/Dockerfile b/config_generator/Dockerfile index 36c836db..00ff5b93 100644 --- a/config_generator/Dockerfile +++ b/config_generator/Dockerfile @@ -1,9 +1,9 @@ FROM python:3-slim as config-generator WORKDIR /usr/src/app -RUN pip install jinja2 +COPY config_generator/requirements.txt . +RUN pip install -r requirements.txt COPY config_generator/config_generator.py . COPY envoyfilter/envoy.template.yaml . COPY envoyfilter/katanemo-config.yaml . -# RUN python config_generator.py > envoy.yaml CMD ["python", "config_generator.py"] diff --git a/config_generator/config_generator.py b/config_generator/config_generator.py index 2693f834..1c1d06ac 100644 --- a/config_generator/config_generator.py +++ b/config_generator/config_generator.py @@ -1,5 +1,6 @@ import os from jinja2 import Environment, FileSystemLoader +import yaml ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml') BOLT_CONFIG_FILE = os.getenv('BOLT_CONFIG_FILE', 'bolt_config.yaml') @@ -11,8 +12,36 @@ template = env.get_template('envoy.template.yaml') with open(BOLT_CONFIG_FILE, 'r') as file: katanemo_config = file.read() +config_yaml = yaml.safe_load(katanemo_config) + +inferred_clusters = {} + +for prompt_target in config_yaml["prompt_targets"]: + cluster = prompt_target.get("endpoint", {}).get("cluster", "") + if cluster not in inferred_clusters: + inferred_clusters[cluster] = { + "name": cluster, + "address": cluster, + "port": 80, # default port + } + +print(inferred_clusters) + +clusters = config_yaml.get("clusters", {}) + +# override the inferred clusters with the ones defined in the config +for name, cluster in clusters.items(): + if name in inferred_clusters: + print("updating cluster", cluster) + inferred_clusters[name].update(cluster) + else: + inferred_clusters[name] = cluster + +print("updated clusters", inferred_clusters) + data = { - 'katanemo_config': katanemo_config + 'katanemo_config': katanemo_config, + 'arch_clusters': inferred_clusters } rendered = template.render(data) diff --git a/config_generator/requirements.txt b/config_generator/requirements.txt new file mode 100644 index 00000000..4e859bb8 --- /dev/null +++ b/config_generator/requirements.txt @@ -0,0 +1,2 @@ +jinja2 +pyyaml diff --git a/demos/function_calling/bolt_config.yaml b/demos/function_calling/bolt_config.yaml index 1c5b1a56..684ac5ba 100644 --- a/demos/function_calling/bolt_config.yaml +++ b/demos/function_calling/bolt_config.yaml @@ -51,3 +51,7 @@ prompt_targets: system_prompt: | You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries: - Use policy number to retrieve insurance claim details + +clusters: + weatherhost: + address: model_server diff --git a/envoyfilter/envoy.template.yaml b/envoyfilter/envoy.template.yaml index bd79b6c0..fb653021 100644 --- a/envoyfilter/envoy.template.yaml +++ b/envoyfilter/envoy.template.yaml @@ -122,7 +122,6 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 - - name: model_server connect_timeout: 5s type: STRICT_DNS @@ -137,34 +136,6 @@ static_resources: address: model_server port_value: 80 hostname: "model_server" - - name: weatherhost - connect_timeout: 5s - type: STRICT_DNS - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: weatherhost - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: model_server - port_value: 80 - hostname: "model_server" - - name: nerhost - connect_timeout: 5s - type: STRICT_DNS - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: nerhost - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: model_server - port_value: 80 - hostname: "model_server" - name: mistral_7b_instruct connect_timeout: 5s type: STRICT_DNS @@ -193,3 +164,19 @@ static_resources: address: function_resolver port_value: 80 hostname: "bolt_fc_1b" +{% for _, cluster in arch_clusters.items() %} + - name: {{ cluster.name }} + connect_timeout: 5s + type: STRICT_DNS + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: {{ cluster.name }} + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: {{ cluster.address }} + port_value: {{ cluster.port }} + hostname: {{ cluster.address }} +{% endfor %} diff --git a/gateway.code-workspace b/gateway.code-workspace index a6dac5c6..90f5c25b 100644 --- a/gateway.code-workspace +++ b/gateway.code-workspace @@ -24,6 +24,10 @@ "name": "open-message-format", "path": "open-message-format" }, + { + "name": "config_generator", + "path": "config_generator" + }, { "name": "demos/function_calling", "path": "./demos/function_calling",