add support for using custom upstream llm (#365)

This commit is contained in:
Adil Hafeez 2025-01-17 18:25:55 -08:00 committed by GitHub
parent 3fc21de60c
commit 07ef3149b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 263 additions and 52 deletions

View file

@ -43,19 +43,27 @@ properties:
properties:
name:
type: string
# this field is deprecated, use provider_interface instead
provider:
type: string
enum:
- openai
provider_interface:
type: string
enum:
- openai
- mistral
access_key:
type: string
model:
type: string
default:
type: boolean
endpoint:
type: string
additionalProperties: false
required:
- name
- provider
- access_key
- model
overrides:
type: object

View file

@ -125,15 +125,21 @@ static_resources:
- "*"
routes:
{% for provider in arch_llm_providers %}
# if endpoint is set then use custom cluster for upstream llm
{% if provider.endpoint %}
{% set llm_cluster_name = provider.name %}
{% else %}
{% set llm_cluster_name = provider.provider_interface %}
{% endif %}
- match:
prefix: "/"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: {{ provider.name }}
exact: {{ llm_cluster_name }}
route:
auto_host_rewrite: true
cluster: {{ provider.provider }}
cluster: {{ llm_cluster_name }}
timeout: 60s
{% endfor %}
http_filters:
@ -237,16 +243,16 @@ static_resources:
domains:
- "*"
routes:
{% for internal_clustrer in ["arch_fc", "model_server"] %}
{% for internal_cluster in ["arch_fc", "model_server"] %}
- match:
prefix: "/"
headers:
- name: "x-arch-upstream"
string_match:
exact: {{ internal_clustrer }}
exact: {{ internal_cluster }}
route:
auto_host_rewrite: true
cluster: {{ internal_clustrer }}
cluster: {{ internal_cluster }}
timeout: 60s
{% endfor %}
@ -370,15 +376,21 @@ static_resources:
cluster: openai
timeout: 60s
{% for provider in arch_llm_providers %}
# if endpoint is set then use custom cluster for upstream llm
{% if provider.endpoint %}
{% set llm_cluster_name = provider.name %}
{% else %}
{% set llm_cluster_name = provider.provider_interface %}
{% endif %}
- match:
prefix: "/"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: {{ provider.name }}
exact: {{ llm_cluster_name }}
route:
auto_host_rewrite: true
cluster: {{ provider.provider }}
cluster: {{ llm_cluster_name }}
timeout: 60s
{% endfor %}
- match:
@ -538,6 +550,24 @@ static_resources:
tls_maximum_protocol_version: TLSv1_3
{% endif %}
{% endfor %}
{% for local_llm_provider in local_llms %}
- name: {{ local_llm_provider.name }}
connect_timeout: 5s
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: {{ local_llm_provider.name }}
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: {{ local_llm_provider.endpoint }}
port_value: {{ local_llm_provider.port }}
hostname: {{ local_llm_provider.endpoint }}
{% endfor %}
- name: arch_internal
connect_timeout: 5s
type: LOGICAL_DNS

View file

@ -16,18 +16,6 @@ ARCH_CONFIG_SCHEMA_FILE = os.getenv(
)
def add_secret_key_to_llm_providers(config_yaml):
llm_providers = []
for llm_provider in config_yaml.get("llm_providers", []):
access_key_env_var = llm_provider.get("access_key", False)
access_key_value = os.getenv(access_key_env_var, False)
if access_key_env_var and access_key_value:
llm_provider["access_key"] = access_key_value
llm_providers.append(llm_provider)
config_yaml["llm_providers"] = llm_providers
return config_yaml
def validate_and_render_schema():
env = Environment(loader=FileSystemLoader("./"))
template = env.get_template("envoy.template.yaml")
@ -70,18 +58,42 @@ def validate_and_render_schema():
f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file"
)
arch_llm_providers = config_yaml["llm_providers"]
arch_tracing = config_yaml.get("tracing", {})
llms_with_endpoint = []
updated_llm_providers = []
for llm_provider in config_yaml["llm_providers"]:
provider = None
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
raise Exception(
"Please provide either provider or provider_interface, not both"
)
if llm_provider.get("provider"):
provider = llm_provider["provider"]
llm_provider["provider_interface"] = provider
del llm_provider["provider"]
updated_llm_providers.append(llm_provider)
if llm_provider.get("endpoint", None):
endpoint = llm_provider["endpoint"]
if len(endpoint.split(":")) > 1:
llm_provider["endpoint"] = endpoint.split(":")[0]
llm_provider["port"] = int(endpoint.split(":")[1])
llms_with_endpoint.append(llm_provider)
config_yaml["llm_providers"] = updated_llm_providers
arch_config_string = yaml.dump(config_yaml)
config_yaml["mode"] = "llm"
arch_llm_config_string = yaml.dump(config_yaml)
data = {
"arch_config": arch_config_string,
"arch_llm_config": arch_llm_config_string,
"arch_clusters": inferred_clusters,
"arch_llm_providers": arch_llm_providers,
"arch_llm_providers": config_yaml["llm_providers"],
"arch_tracing": arch_tracing,
"local_llms": llms_with_endpoint,
}
rendered = template.render(data)