Add ability to define clusters in config (#61)

This commit is contained in:
Adil Hafeez 2024-09-18 20:03:26 -07:00 committed by GitHub
parent 215d276acf
commit a91fbdbf1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 59 additions and 32 deletions

View file

@ -3,6 +3,7 @@ repos:
rev: v4.6.0
hooks:
- id: check-yaml
  # The Envoy template contains Jinja tags, so it is not valid YAML until rendered.
  # `exclude` is a regular expression: anchor it and escape the dots so only this
  # exact path is skipped.
  exclude: '^envoyfilter/envoy\.template\.yaml$'
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: local

View file

@ -1,9 +1,9 @@
# Image stage that runs config_generator.py against the copied template and config.
FROM python:3-slim AS config-generator
WORKDIR /usr/src/app
# Install dependencies from requirements.txt (which already lists jinja2) before
# copying the script, so this layer stays cached when only sources change.
COPY config_generator/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY config_generator/config_generator.py .
COPY envoyfilter/envoy.template.yaml .
COPY envoyfilter/katanemo-config.yaml .
# RUN python config_generator.py > envoy.yaml
CMD ["python", "config_generator.py"]

View file

@ -1,5 +1,6 @@
import os
from jinja2 import Environment, FileSystemLoader
import yaml
# Template file name, overridable through the ENVOY_CONFIG_TEMPLATE_FILE environment variable.
# NOTE(review): the visible get_template() call passes the literal 'envoy.template.yaml'
# rather than this constant — confirm the override is actually consumed.
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml')
# Path of the config file read below, overridable through BOLT_CONFIG_FILE.
# NOTE(review): the generator's Dockerfile copies `katanemo-config.yaml`, not the
# `bolt_config.yaml` default used here — confirm BOLT_CONFIG_FILE is set where the image runs.
BOLT_CONFIG_FILE = os.getenv('BOLT_CONFIG_FILE', 'bolt_config.yaml')
@ -11,8 +12,36 @@ template = env.get_template('envoy.template.yaml')
# Keep the raw text (it is handed to the template verbatim) and also parse a
# structured copy that the cluster inference below inspects.
with open(BOLT_CONFIG_FILE, 'r') as config_file:
    katanemo_config = config_file.read()

config_yaml = yaml.safe_load(katanemo_config)
# Infer one cluster per distinct `endpoint.cluster` referenced by a prompt target.
# Each inferred cluster defaults to an address equal to its name and port 80.
inferred_clusters = {}

for prompt_target in config_yaml["prompt_targets"]:
    # `endpoint` may be missing or explicitly null; both mean "no cluster".
    cluster = (prompt_target.get("endpoint") or {}).get("cluster", "")
    if not cluster:
        # Without this guard a target that has no endpoint would register a
        # cluster whose name and address are empty strings.
        continue
    if cluster not in inferred_clusters:
        inferred_clusters[cluster] = {
            "name": cluster,
            "address": cluster,
            "port": 80,  # default port
        }

print(inferred_clusters)
# Clusters declared explicitly under the top-level `clusters:` key.
# `or {}` also covers the key being present with no value (parsed as None).
clusters = config_yaml.get("clusters") or {}

# override the inferred clusters with the ones defined in the config
for name, cluster in clusters.items():
    # A bare `name:` entry parses as None; treat it as "no overrides".
    overrides = cluster or {}
    if name in inferred_clusters:
        print("updating cluster", cluster)
        inferred_clusters[name].update(overrides)
    else:
        # A cluster that no prompt target references still needs the fields the
        # template reads (name, address, port), so start from the same defaults
        # used for inferred clusters and layer the declared values on top.
        inferred_clusters[name] = {
            "name": name,
            "address": name,
            "port": 80,  # default port
            **overrides,
        }

print("updated clusters", inferred_clusters)
# Context handed to the template: the raw config text plus the merged cluster map.
# The stale pre-change entry is dropped; it repeated the 'katanemo_config' key
# without a trailing comma, which is a syntax error.
data = {
    'katanemo_config': katanemo_config,
    'arch_clusters': inferred_clusters,
}

rendered = template.render(data)

View file

@ -0,0 +1,2 @@
jinja2
pyyaml

View file

@ -51,3 +51,7 @@ prompt_targets:
system_prompt: |
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
- Use policy number to retrieve insurance claim details
# Optional per-cluster overrides, keyed by cluster name. The config generator
# merges each entry over the cluster it inferred from prompt_targets, whose
# address defaults to the cluster name and whose port defaults to 80.
clusters:
weatherhost:
# Used as the socket address (and hostname) of the generated `weatherhost` cluster.
address: model_server

View file

@ -122,7 +122,6 @@ static_resources:
tls_params:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
- name: model_server
connect_timeout: 5s
type: STRICT_DNS
@ -137,34 +136,6 @@ static_resources:
address: model_server
port_value: 80
hostname: "model_server"
# Static cluster `weatherhost`: STRICT_DNS, round-robin, single endpoint model_server:80.
# NOTE(review): the templated loop at the end of this cluster list emits one
# cluster per `arch_clusters` entry, and the sample config declares
# `weatherhost` under `clusters:` — confirm this static entry is dropped so the
# same cluster name is not defined twice.
- name: weatherhost
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: weatherhost
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: model_server
port_value: 80
hostname: "model_server"
# Static cluster `nerhost`: STRICT_DNS, round-robin, single endpoint model_server:80.
# NOTE(review): if any prompt target's endpoint references `nerhost`, the
# templated loop at the end of this cluster list will also emit it — confirm
# this static entry is dropped in favour of the generated one.
- name: nerhost
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: nerhost
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: model_server
port_value: 80
hostname: "model_server"
- name: mistral_7b_instruct
connect_timeout: 5s
type: STRICT_DNS
@ -193,3 +164,19 @@ static_resources:
address: function_resolver
port_value: 80
hostname: "bolt_fc_1b"
{# Emits one STRICT_DNS cluster per entry of arch_clusters (cluster name -> settings). #}
{# String values are quoted so an empty or boolean-looking expansion stays a string; #}
{# missing fields fall back to the map key (name, address) and to port 80. #}
{# NOTE(review): entries are indented for a list nested under static_resources.clusters; confirm against the static entries above. #}
{% for name, cluster in arch_clusters.items() %}
    - name: "{{ cluster.name | default(name) }}"
      connect_timeout: 5s
      type: STRICT_DNS
      lb_policy: ROUND_ROBIN
      load_assignment:
        cluster_name: "{{ cluster.name | default(name) }}"
        endpoints:
          - lb_endpoints:
              - endpoint:
                  address:
                    socket_address:
                      address: "{{ cluster.address | default(name) }}"
                      port_value: {{ cluster.port | default(80) }}
                  hostname: "{{ cluster.address | default(name) }}"
{% endfor %}

View file

@ -24,6 +24,10 @@
"name": "open-message-format",
"path": "open-message-format"
},
{
"name": "config_generator",
"path": "config_generator"
},
{
"name": "demos/function_calling",
"path": "./demos/function_calling",