mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
add support for model_providers
This commit is contained in:
parent
2cebc0c85f
commit
92a8782332
10 changed files with 183 additions and 93 deletions
|
|
@ -15,9 +15,30 @@ properties:
|
||||||
items:
|
items:
|
||||||
type: object
|
type: object
|
||||||
listeners:
|
listeners:
|
||||||
anyOf:
|
oneOf:
|
||||||
- type: array
|
- type: array
|
||||||
- type: object
|
additionalProperties: false
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
port:
|
||||||
|
type: integer
|
||||||
|
address:
|
||||||
|
type: string
|
||||||
|
timeout:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- model_listener
|
||||||
|
- prompt_listener
|
||||||
|
- agent_listener
|
||||||
|
required:
|
||||||
|
- type
|
||||||
|
- name
|
||||||
|
- type: object # deprecated legacy format, use list format instead
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
ingress_traffic:
|
ingress_traffic:
|
||||||
|
|
@ -69,7 +90,52 @@ properties:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- endpoint
|
- endpoint
|
||||||
llm_providers:
|
|
||||||
|
model_providers:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
access_key:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
default:
|
||||||
|
type: boolean
|
||||||
|
base_url:
|
||||||
|
type: string
|
||||||
|
http_host:
|
||||||
|
type: string
|
||||||
|
provider_interface:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- arch
|
||||||
|
- claude
|
||||||
|
- deepseek
|
||||||
|
- groq
|
||||||
|
- mistral
|
||||||
|
- openai
|
||||||
|
- gemini
|
||||||
|
routing_preferences:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
description:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- name
|
||||||
|
- description
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- model
|
||||||
|
|
||||||
|
llm_providers: # deprecated for legacy support, use model_providers instead
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
type: object
|
type: object
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from cli.utils import convert_legacy_llm_providers
|
from cli.utils import convert_legacy_listeners
|
||||||
from jinja2 import Environment, FileSystemLoader
|
from jinja2 import Environment, FileSystemLoader
|
||||||
import yaml
|
import yaml
|
||||||
from jsonschema import validate
|
from jsonschema import validate
|
||||||
|
|
@ -71,8 +71,17 @@ def validate_and_render_schema():
|
||||||
_ = yaml.safe_load(arch_config_schema)
|
_ = yaml.safe_load(arch_config_schema)
|
||||||
inferred_clusters = {}
|
inferred_clusters = {}
|
||||||
|
|
||||||
listeners, llm_gateway, prompt_gateway = convert_legacy_llm_providers(
|
# Convert legacy llm_providers to model_providers
|
||||||
config_yaml.get("listeners"), config_yaml.get("llm_providers")
|
if "llm_providers" in config_yaml:
|
||||||
|
if "model_providers" in config_yaml:
|
||||||
|
raise Exception(
|
||||||
|
"Please provide either llm_providers or model_providers, not both. llm_providers is deprecated, please use model_providers instead"
|
||||||
|
)
|
||||||
|
config_yaml["model_providers"] = config_yaml["llm_providers"]
|
||||||
|
del config_yaml["llm_providers"]
|
||||||
|
|
||||||
|
listeners, llm_gateway, prompt_gateway = convert_legacy_listeners(
|
||||||
|
config_yaml.get("listeners"), config_yaml.get("model_providers")
|
||||||
)
|
)
|
||||||
|
|
||||||
config_yaml["listeners"] = listeners
|
config_yaml["listeners"] = listeners
|
||||||
|
|
@ -130,36 +139,39 @@ def validate_and_render_schema():
|
||||||
arch_tracing = config_yaml.get("tracing", {})
|
arch_tracing = config_yaml.get("tracing", {})
|
||||||
|
|
||||||
llms_with_endpoint = []
|
llms_with_endpoint = []
|
||||||
updated_llm_providers = []
|
updated_model_providers = []
|
||||||
llm_provider_name_set = set()
|
model_provider_name_set = set()
|
||||||
llms_with_usage = []
|
llms_with_usage = []
|
||||||
model_name_keys = set()
|
model_name_keys = set()
|
||||||
model_usage_name_keys = set()
|
model_usage_name_keys = set()
|
||||||
|
|
||||||
for listener in listeners:
|
for listener in listeners:
|
||||||
if listener.get("llm_providers") is None or listener.get("llm_providers") == []:
|
if (
|
||||||
|
listener.get("model_providers") is None
|
||||||
|
or listener.get("model_providers") == []
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
print("Processing listener with llm_providers: ", listener)
|
print("Processing listener with model_providers: ", listener)
|
||||||
name = listener.get("name", None)
|
name = listener.get("name", None)
|
||||||
|
|
||||||
for llm_provider in listener.get("llm_providers", []):
|
for model_provider in listener.get("model_providers", []):
|
||||||
if llm_provider.get("usage", None):
|
if model_provider.get("usage", None):
|
||||||
llms_with_usage.append(llm_provider["name"])
|
llms_with_usage.append(model_provider["name"])
|
||||||
if llm_provider.get("name") in llm_provider_name_set:
|
if model_provider.get("name") in model_provider_name_set:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
|
f"Duplicate model_provider name {model_provider.get('name')}, please provide unique name for each model_provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
model_name = llm_provider.get("model")
|
model_name = model_provider.get("model")
|
||||||
if model_name in model_name_keys:
|
if model_name in model_name_keys:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
|
f"Duplicate model name {model_name}, please provide unique model name for each model_provider"
|
||||||
)
|
)
|
||||||
model_name_keys.add(model_name)
|
model_name_keys.add(model_name)
|
||||||
if llm_provider.get("name") is None:
|
if model_provider.get("name") is None:
|
||||||
llm_provider["name"] = model_name
|
model_provider["name"] = model_name
|
||||||
|
|
||||||
llm_provider_name_set.add(llm_provider.get("name"))
|
model_provider_name_set.add(model_provider.get("name"))
|
||||||
|
|
||||||
model_name_tokens = model_name.split("/")
|
model_name_tokens = model_name.split("/")
|
||||||
if len(model_name_tokens) < 2:
|
if len(model_name_tokens) < 2:
|
||||||
|
|
@ -171,7 +183,7 @@ def validate_and_render_schema():
|
||||||
# Validate azure_openai and ollama provider requires base_url
|
# Validate azure_openai and ollama provider requires base_url
|
||||||
if (
|
if (
|
||||||
provider == "azure_openai" or provider == "ollama"
|
provider == "azure_openai" or provider == "ollama"
|
||||||
) and llm_provider.get("base_url") is None:
|
) and model_provider.get("base_url") is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
|
f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
|
||||||
)
|
)
|
||||||
|
|
@ -179,46 +191,48 @@ def validate_and_render_schema():
|
||||||
model_id = "/".join(model_name_tokens[1:])
|
model_id = "/".join(model_name_tokens[1:])
|
||||||
if provider not in SUPPORTED_PROVIDERS:
|
if provider not in SUPPORTED_PROVIDERS:
|
||||||
if (
|
if (
|
||||||
llm_provider.get("base_url", None) is None
|
model_provider.get("base_url", None) is None
|
||||||
or llm_provider.get("provider_interface", None) is None
|
or model_provider.get("provider_interface", None) is None
|
||||||
):
|
):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
|
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
|
||||||
)
|
)
|
||||||
provider = llm_provider.get("provider_interface", None)
|
provider = model_provider.get("provider_interface", None)
|
||||||
elif llm_provider.get("provider_interface", None) is not None:
|
elif model_provider.get("provider_interface", None) is not None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
|
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_id in model_name_keys:
|
if model_id in model_name_keys:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Duplicate model_id {model_id}, please provide unique model_id for each llm_provider"
|
f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider"
|
||||||
)
|
)
|
||||||
model_name_keys.add(model_id)
|
model_name_keys.add(model_id)
|
||||||
|
|
||||||
for routing_preference in llm_provider.get("routing_preferences", []):
|
for routing_preference in model_provider.get("routing_preferences", []):
|
||||||
if routing_preference.get("name") in model_usage_name_keys:
|
if routing_preference.get("name") in model_usage_name_keys:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference"
|
f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference"
|
||||||
)
|
)
|
||||||
model_usage_name_keys.add(routing_preference.get("name"))
|
model_usage_name_keys.add(routing_preference.get("name"))
|
||||||
|
|
||||||
llm_provider["model"] = model_id
|
model_provider["model"] = model_id
|
||||||
llm_provider["provider_interface"] = provider
|
model_provider["provider_interface"] = provider
|
||||||
llm_provider_name_set.add(llm_provider.get("name"))
|
model_provider_name_set.add(model_provider.get("name"))
|
||||||
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
|
if model_provider.get("provider") and model_provider.get(
|
||||||
|
"provider_interface"
|
||||||
|
):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Please provide either provider or provider_interface, not both"
|
"Please provide either provider or provider_interface, not both"
|
||||||
)
|
)
|
||||||
if llm_provider.get("provider"):
|
if model_provider.get("provider"):
|
||||||
provider = llm_provider["provider"]
|
provider = model_provider["provider"]
|
||||||
llm_provider["provider_interface"] = provider
|
model_provider["provider_interface"] = provider
|
||||||
del llm_provider["provider"]
|
del model_provider["provider"]
|
||||||
updated_llm_providers.append(llm_provider)
|
updated_model_providers.append(model_provider)
|
||||||
|
|
||||||
if llm_provider.get("base_url", None):
|
if model_provider.get("base_url", None):
|
||||||
base_url = llm_provider["base_url"]
|
base_url = model_provider["base_url"]
|
||||||
urlparse_result = urlparse(base_url)
|
urlparse_result = urlparse(base_url)
|
||||||
url_path = urlparse_result.path
|
url_path = urlparse_result.path
|
||||||
if url_path and url_path != "/":
|
if url_path and url_path != "/":
|
||||||
|
|
@ -240,22 +254,30 @@ def validate_and_render_schema():
|
||||||
else:
|
else:
|
||||||
port = 443
|
port = 443
|
||||||
endpoint = urlparse_result.hostname
|
endpoint = urlparse_result.hostname
|
||||||
llm_provider["endpoint"] = endpoint
|
model_provider["endpoint"] = endpoint
|
||||||
llm_provider["port"] = port
|
model_provider["port"] = port
|
||||||
llm_provider["protocol"] = protocol
|
model_provider["protocol"] = protocol
|
||||||
llm_provider["cluster_name"] = (
|
model_provider["cluster_name"] = (
|
||||||
provider + "_" + endpoint
|
provider + "_" + endpoint
|
||||||
) # make name unique by appending endpoint
|
) # make name unique by appending endpoint
|
||||||
llms_with_endpoint.append(llm_provider)
|
llms_with_endpoint.append(model_provider)
|
||||||
|
|
||||||
if len(model_usage_name_keys) > 0:
|
if len(model_usage_name_keys) > 0:
|
||||||
routing_llm_provider = config_yaml.get("routing", {}).get("llm_provider", None)
|
routing_model_provider = config_yaml.get("routing", {}).get(
|
||||||
if routing_llm_provider and routing_llm_provider not in llm_provider_name_set:
|
"model_provider", None
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
routing_model_provider
|
||||||
|
and routing_model_provider not in model_provider_name_set
|
||||||
|
):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Routing llm_provider {routing_llm_provider} is not defined in llm_providers"
|
f"Routing model_provider {routing_model_provider} is not defined in model_providers"
|
||||||
)
|
)
|
||||||
if routing_llm_provider is None and "arch-router" not in llm_provider_name_set:
|
if (
|
||||||
updated_llm_providers.append(
|
routing_model_provider is None
|
||||||
|
and "arch-router" not in model_provider_name_set
|
||||||
|
):
|
||||||
|
updated_model_providers.append(
|
||||||
{
|
{
|
||||||
"name": "arch-router",
|
"name": "arch-router",
|
||||||
"provider_interface": "arch",
|
"provider_interface": "arch",
|
||||||
|
|
@ -263,19 +285,19 @@ def validate_and_render_schema():
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_llm_providers = []
|
updated_model_providers = []
|
||||||
for listener in listeners:
|
for listener in listeners:
|
||||||
print("Processing listener: ", listener)
|
print("Processing listener: ", listener)
|
||||||
llm_providers = listener.get("llm_providers", None)
|
model_providers = listener.get("model_providers", None)
|
||||||
if llm_providers is not None and llm_providers != []:
|
if model_providers is not None and model_providers != []:
|
||||||
print("processing egress traffic listener")
|
print("processing egress traffic listener")
|
||||||
print("updated_llm_providers: ", updated_llm_providers)
|
print("updated_model_providers: ", updated_model_providers)
|
||||||
if updated_llm_providers is not None and updated_llm_providers != []:
|
if updated_model_providers is not None and updated_model_providers != []:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Please provide llm_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with llm_providers"
|
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
||||||
)
|
)
|
||||||
updated_llm_providers = deepcopy(llm_providers)
|
updated_model_providers = deepcopy(model_providers)
|
||||||
config_yaml["llm_providers"] = updated_llm_providers
|
config_yaml["model_providers"] = updated_model_providers
|
||||||
|
|
||||||
# Validate model aliases if present
|
# Validate model aliases if present
|
||||||
if "model_aliases" in config_yaml:
|
if "model_aliases" in config_yaml:
|
||||||
|
|
@ -317,7 +339,7 @@ def validate_and_render_schema():
|
||||||
"arch_config": arch_config_string,
|
"arch_config": arch_config_string,
|
||||||
"arch_llm_config": arch_llm_config_string,
|
"arch_llm_config": arch_llm_config_string,
|
||||||
"arch_clusters": inferred_clusters,
|
"arch_clusters": inferred_clusters,
|
||||||
"arch_llm_providers": updated_llm_providers,
|
"arch_model_providers": updated_model_providers,
|
||||||
"arch_tracing": arch_tracing,
|
"arch_tracing": arch_tracing,
|
||||||
"local_llms": llms_with_endpoint,
|
"local_llms": llms_with_endpoint,
|
||||||
"agent_orchestrator": agent_orchestrator,
|
"agent_orchestrator": agent_orchestrator,
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import time
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from cli.utils import convert_legacy_llm_providers, getLogger
|
from cli.utils import convert_legacy_listeners, getLogger
|
||||||
from cli.consts import (
|
from cli.consts import (
|
||||||
ARCHGW_DOCKER_IMAGE,
|
ARCHGW_DOCKER_IMAGE,
|
||||||
ARCHGW_DOCKER_NAME,
|
ARCHGW_DOCKER_NAME,
|
||||||
|
|
@ -37,7 +37,7 @@ def _get_gateway_ports(arch_config_file: str) -> list[int]:
|
||||||
|
|
||||||
print("arch config dict json string: ", json.dumps(arch_config_dict))
|
print("arch config dict json string: ", json.dumps(arch_config_dict))
|
||||||
|
|
||||||
listeners, _, _ = convert_legacy_llm_providers(
|
listeners, _, _ = convert_legacy_listeners(
|
||||||
arch_config_dict.get("listeners"), arch_config_dict.get("llm_providers")
|
arch_config_dict.get("listeners"), arch_config_dict.get("llm_providers")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,20 +37,22 @@ def has_ingress_listener(arch_config_file):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def convert_legacy_llm_providers(
|
def convert_legacy_listeners(
|
||||||
listeners: dict | list, llm_providers: list | None
|
listeners: dict | list, model_providers: list | None
|
||||||
) -> tuple[list, dict | None, dict | None]:
|
) -> tuple[list, dict | None, dict | None]:
|
||||||
llm_gateway_listener = {
|
llm_gateway_listener = {
|
||||||
"name": "egress_traffic",
|
"name": "egress_traffic",
|
||||||
|
"type": "model_listener",
|
||||||
"port": 12000,
|
"port": 12000,
|
||||||
"address": "0.0.0.0",
|
"address": "0.0.0.0",
|
||||||
"timeout": "30s",
|
"timeout": "30s",
|
||||||
"llm_providers": [],
|
"model_providers": model_providers or [],
|
||||||
"protocol": "openai",
|
"protocol": "openai",
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt_gateway_listener = {
|
prompt_gateway_listener = {
|
||||||
"name": "ingress_traffic",
|
"name": "ingress_traffic",
|
||||||
|
"type": "prompt_listener",
|
||||||
"port": 10000,
|
"port": 10000,
|
||||||
"address": "0.0.0.0",
|
"address": "0.0.0.0",
|
||||||
"timeout": "30s",
|
"timeout": "30s",
|
||||||
|
|
@ -74,10 +76,10 @@ def convert_legacy_llm_providers(
|
||||||
llm_gateway_listener["timeout"] = egress_traffic.get(
|
llm_gateway_listener["timeout"] = egress_traffic.get(
|
||||||
"timeout", llm_gateway_listener["timeout"]
|
"timeout", llm_gateway_listener["timeout"]
|
||||||
)
|
)
|
||||||
if llm_providers is None or llm_providers == []:
|
if model_providers is None or model_providers == []:
|
||||||
raise ValueError("llm_providers cannot be empty when using legacy format")
|
raise ValueError("model_providers cannot be empty when using legacy format")
|
||||||
|
|
||||||
llm_gateway_listener["llm_providers"] = llm_providers
|
llm_gateway_listener["model_providers"] = model_providers
|
||||||
updated_listeners.append(llm_gateway_listener)
|
updated_listeners.append(llm_gateway_listener)
|
||||||
|
|
||||||
if ingress_traffic and ingress_traffic != {}:
|
if ingress_traffic and ingress_traffic != {}:
|
||||||
|
|
@ -94,15 +96,16 @@ def convert_legacy_llm_providers(
|
||||||
|
|
||||||
return updated_listeners, llm_gateway_listener, prompt_gateway_listener
|
return updated_listeners, llm_gateway_listener, prompt_gateway_listener
|
||||||
|
|
||||||
llm_provider_set = False
|
model_provider_set = False
|
||||||
for listener in listeners:
|
for listener in listeners:
|
||||||
if listener.get("llm_providers") is not None:
|
if listener.get("type") == "model_listener":
|
||||||
if llm_provider_set:
|
if model_provider_set:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Currently only one listener can have llm_providers set"
|
"Currently only one listener can have model_providers set"
|
||||||
)
|
)
|
||||||
|
listener["model_providers"] = model_providers or []
|
||||||
|
model_provider_set = True
|
||||||
llm_gateway_listener = listener
|
llm_gateway_listener = listener
|
||||||
llm_provider_set = True
|
|
||||||
|
|
||||||
return listeners, llm_gateway_listener, prompt_gateway_listener
|
return listeners, llm_gateway_listener, prompt_gateway_listener
|
||||||
|
|
||||||
|
|
@ -113,7 +116,7 @@ def get_llm_provider_access_keys(arch_config_file):
|
||||||
arch_config_yaml = yaml.safe_load(arch_config)
|
arch_config_yaml = yaml.safe_load(arch_config)
|
||||||
|
|
||||||
access_key_list = []
|
access_key_list = []
|
||||||
listeners, _, _ = convert_legacy_llm_providers(
|
listeners, _, _ = convert_legacy_listeners(
|
||||||
arch_config_yaml.get("listeners"), arch_config_yaml.get("llm_providers")
|
arch_config_yaml.get("listeners"), arch_config_yaml.get("llm_providers")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -350,7 +350,7 @@ def test_validate_and_render_schema_tests(monkeypatch, arch_config_test_case):
|
||||||
|
|
||||||
|
|
||||||
def test_convert_legacy_llm_providers():
|
def test_convert_legacy_llm_providers():
|
||||||
from cli.utils import convert_legacy_llm_providers
|
from cli.utils import convert_legacy_listeners
|
||||||
|
|
||||||
listeners = {
|
listeners = {
|
||||||
"ingress_traffic": {
|
"ingress_traffic": {
|
||||||
|
|
@ -373,7 +373,7 @@ def test_convert_legacy_llm_providers():
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
updated_providers, llm_gateway, prompt_gateway = convert_legacy_llm_providers(
|
updated_providers, llm_gateway, prompt_gateway = convert_legacy_listeners(
|
||||||
listeners, llm_providers
|
listeners, llm_providers
|
||||||
)
|
)
|
||||||
assert isinstance(updated_providers, list)
|
assert isinstance(updated_providers, list)
|
||||||
|
|
@ -425,7 +425,7 @@ def test_convert_legacy_llm_providers():
|
||||||
|
|
||||||
|
|
||||||
def test_convert_legacy_llm_providers_no_prompt_gateway():
|
def test_convert_legacy_llm_providers_no_prompt_gateway():
|
||||||
from cli.utils import convert_legacy_llm_providers
|
from cli.utils import convert_legacy_listeners
|
||||||
|
|
||||||
listeners = {
|
listeners = {
|
||||||
"egress_traffic": {
|
"egress_traffic": {
|
||||||
|
|
@ -442,7 +442,7 @@ def test_convert_legacy_llm_providers_no_prompt_gateway():
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
updated_providers, llm_gateway, prompt_gateway = convert_legacy_llm_providers(
|
updated_providers, llm_gateway, prompt_gateway = convert_legacy_listeners(
|
||||||
listeners, llm_providers
|
listeners, llm_providers
|
||||||
)
|
)
|
||||||
assert isinstance(updated_providers, list)
|
assert isinstance(updated_providers, list)
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
|
||||||
let arch_config = Arc::new(config);
|
let arch_config = Arc::new(config);
|
||||||
|
|
||||||
let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone()));
|
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
|
||||||
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
|
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
|
||||||
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
|
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
|
||||||
|
|
||||||
|
|
@ -87,11 +87,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let routing_llm_provider = arch_config
|
let routing_llm_provider = arch_config
|
||||||
.routing
|
.routing
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|r| r.llm_provider.clone())
|
.and_then(|r| r.model_provider.clone())
|
||||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||||
|
|
||||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||||
arch_config.llm_providers.clone(),
|
arch_config.model_providers.clone(),
|
||||||
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
|
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
|
||||||
routing_model_name,
|
routing_model_name,
|
||||||
routing_llm_provider,
|
routing_llm_provider,
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ use crate::api::open_ai::{
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Routing {
|
pub struct Routing {
|
||||||
pub llm_provider: Option<String>,
|
pub model_provider: Option<String>,
|
||||||
pub model: Option<String>,
|
pub model: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -46,7 +46,7 @@ pub struct Listener {
|
||||||
pub struct Configuration {
|
pub struct Configuration {
|
||||||
pub version: String,
|
pub version: String,
|
||||||
pub endpoints: Option<HashMap<String, Endpoint>>,
|
pub endpoints: Option<HashMap<String, Endpoint>>,
|
||||||
pub llm_providers: Vec<LlmProvider>,
|
pub model_providers: Vec<LlmProvider>,
|
||||||
pub model_aliases: Option<HashMap<String, ModelAlias>>,
|
pub model_aliases: Option<HashMap<String, ModelAlias>>,
|
||||||
pub overrides: Option<Overrides>,
|
pub overrides: Option<Overrides>,
|
||||||
pub system_prompt: Option<String>,
|
pub system_prompt: Option<String>,
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ impl RootContext for FilterContext {
|
||||||
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
|
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
|
||||||
self.overrides = Rc::new(config.overrides);
|
self.overrides = Rc::new(config.overrides);
|
||||||
|
|
||||||
match config.llm_providers.try_into() {
|
match config.model_providers.try_into() {
|
||||||
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
|
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
|
||||||
Err(err) => panic!("{err}"),
|
Err(err) => panic!("{err}"),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
version: v0.1.0
|
version: v0.3.0
|
||||||
|
|
||||||
listeners:
|
listeners:
|
||||||
egress_traffic:
|
- type: model_listener
|
||||||
|
name: model_listener_1
|
||||||
address: 0.0.0.0
|
address: 0.0.0.0
|
||||||
port: 12000
|
port: 12000
|
||||||
message_format: openai
|
|
||||||
timeout: 30s
|
|
||||||
|
|
||||||
llm_providers:
|
model_providers:
|
||||||
|
|
||||||
- access_key: $OPENAI_API_KEY
|
- access_key: $OPENAI_API_KEY
|
||||||
model: openai/gpt-4o-mini
|
model: openai/gpt-4o-mini
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ endpoints:
|
||||||
port: 8001
|
port: 8001
|
||||||
listeners:
|
listeners:
|
||||||
- address: 0.0.0.0
|
- address: 0.0.0.0
|
||||||
llm_providers:
|
model_providers:
|
||||||
- access_key: $OPENAI_API_KEY
|
- access_key: $OPENAI_API_KEY
|
||||||
default: true
|
default: true
|
||||||
model: gpt-4o
|
model: gpt-4o
|
||||||
|
|
@ -38,7 +38,12 @@ listeners:
|
||||||
port: 10000
|
port: 10000
|
||||||
protocol: openai
|
protocol: openai
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
llm_providers:
|
model_aliases:
|
||||||
|
arch.summarize.v1:
|
||||||
|
target: gpt-4o
|
||||||
|
arch.v1:
|
||||||
|
target: mistral-8x7b
|
||||||
|
model_providers:
|
||||||
- access_key: $OPENAI_API_KEY
|
- access_key: $OPENAI_API_KEY
|
||||||
default: true
|
default: true
|
||||||
model: gpt-4o
|
model: gpt-4o
|
||||||
|
|
@ -56,11 +61,6 @@ llm_providers:
|
||||||
port: 80
|
port: 80
|
||||||
protocol: http
|
protocol: http
|
||||||
provider_interface: mistral
|
provider_interface: mistral
|
||||||
model_aliases:
|
|
||||||
arch.summarize.v1:
|
|
||||||
target: gpt-4o
|
|
||||||
arch.v1:
|
|
||||||
target: mistral-8x7b
|
|
||||||
overrides:
|
overrides:
|
||||||
prompt_target_intent_matching_threshold: 0.6
|
prompt_target_intent_matching_threshold: 0.6
|
||||||
prompt_guards:
|
prompt_guards:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue