diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index 7929a7a1..7719d364 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -15,9 +15,30 @@ properties: items: type: object listeners: - anyOf: + oneOf: - type: array - - type: object + additionalProperties: false + items: + type: object + properties: + name: + type: string + port: + type: integer + address: + type: string + timeout: + type: string + type: + type: string + enum: + - model_listener + - prompt_listener + - agent_listener + required: + - type + - name + - type: object # deprecated legacy format, use list format instead additionalProperties: false properties: ingress_traffic: @@ -69,7 +90,52 @@ properties: additionalProperties: false required: - endpoint - llm_providers: + + model_providers: + type: array + items: + type: object + properties: + name: + type: string + access_key: + type: string + model: + type: string + default: + type: boolean + base_url: + type: string + http_host: + type: string + provider_interface: + type: string + enum: + - arch + - claude + - deepseek + - groq + - mistral + - openai + - gemini + routing_preferences: + type: array + items: + type: object + properties: + name: + type: string + description: + type: string + additionalProperties: false + required: + - name + - description + additionalProperties: false + required: + - model + + llm_providers: # deprecated for legacy support, use model_providers instead type: array items: type: object diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 642ca527..ec188e82 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -1,6 +1,6 @@ import json import os -from cli.utils import convert_legacy_llm_providers +from cli.utils import convert_legacy_listeners from jinja2 import Environment, FileSystemLoader import yaml from jsonschema import validate @@ -71,8 +71,17 @@ def validate_and_render_schema(): _ = yaml.safe_load(arch_config_schema) inferred_clusters = {} - listeners, llm_gateway, prompt_gateway = convert_legacy_llm_providers( - config_yaml.get("listeners"), config_yaml.get("llm_providers") + # Convert legacy llm_providers to model_providers + if "llm_providers" in config_yaml: + if "model_providers" in config_yaml: + raise Exception( + "Please provide either llm_providers or model_providers, not both. llm_providers is deprecated, please use model_providers instead" + ) + config_yaml["model_providers"] = config_yaml["llm_providers"] + del config_yaml["llm_providers"] + + listeners, llm_gateway, prompt_gateway = convert_legacy_listeners( + config_yaml.get("listeners"), config_yaml.get("model_providers") ) config_yaml["listeners"] = listeners @@ -130,36 +139,39 @@ def validate_and_render_schema(): arch_tracing = config_yaml.get("tracing", {}) llms_with_endpoint = [] - updated_llm_providers = [] - llm_provider_name_set = set() + updated_model_providers = [] + model_provider_name_set = set() llms_with_usage = [] model_name_keys = set() model_usage_name_keys = set() for listener in listeners: - if listener.get("llm_providers") is None or listener.get("llm_providers") == []: + if ( + listener.get("model_providers") is None + or listener.get("model_providers") == [] + ): continue - print("Processing listener with llm_providers: ", listener) + print("Processing listener with model_providers: ", listener) name = listener.get("name", None) - for llm_provider in listener.get("llm_providers", []): - if llm_provider.get("usage", None): - llms_with_usage.append(llm_provider["name"]) - if llm_provider.get("name") in llm_provider_name_set: + for model_provider in listener.get("model_providers", []): + if model_provider.get("usage", None): + llms_with_usage.append(model_provider["name"]) + if model_provider.get("name") in model_provider_name_set: raise Exception( - f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider" + f"Duplicate model_provider name {model_provider.get('name')}, please provide unique name for each model_provider" ) - model_name = llm_provider.get("model") + model_name = model_provider.get("model") if model_name in model_name_keys: raise Exception( - f"Duplicate model name {model_name}, please provide unique model name for each llm_provider" + f"Duplicate model name {model_name}, please provide unique model name for each model_provider" ) model_name_keys.add(model_name) - if llm_provider.get("name") is None: - llm_provider["name"] = model_name + if model_provider.get("name") is None: + model_provider["name"] = model_name - llm_provider_name_set.add(llm_provider.get("name")) + model_provider_name_set.add(model_provider.get("name")) model_name_tokens = model_name.split("/") if len(model_name_tokens) < 2: @@ -171,7 +183,7 @@ def validate_and_render_schema(): # Validate azure_openai and ollama provider requires base_url if ( provider == "azure_openai" or provider == "ollama" - ) and llm_provider.get("base_url") is None: + ) and model_provider.get("base_url") is None: raise Exception( f"Provider '{provider}' requires 'base_url' to be set for model {model_name}" ) @@ -179,46 +191,48 @@ def validate_and_render_schema(): model_id = "/".join(model_name_tokens[1:]) if provider not in SUPPORTED_PROVIDERS: if ( - llm_provider.get("base_url", None) is None - or llm_provider.get("provider_interface", None) is None + model_provider.get("base_url", None) is None + or model_provider.get("provider_interface", None) is None ): raise Exception( f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}" ) - provider = llm_provider.get("provider_interface", None) - elif llm_provider.get("provider_interface", None) is not None: + provider = model_provider.get("provider_interface", None) + elif model_provider.get("provider_interface", None) is not None: raise Exception( f"Please provide provider interface as part of model name {model_name} using the format /. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' " ) if model_id in model_name_keys: raise Exception( - f"Duplicate model_id {model_id}, please provide unique model_id for each llm_provider" + f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider" ) model_name_keys.add(model_id) - for routing_preference in llm_provider.get("routing_preferences", []): + for routing_preference in model_provider.get("routing_preferences", []): if routing_preference.get("name") in model_usage_name_keys: raise Exception( f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference" ) model_usage_name_keys.add(routing_preference.get("name")) - llm_provider["model"] = model_id - llm_provider["provider_interface"] = provider - llm_provider_name_set.add(llm_provider.get("name")) - if llm_provider.get("provider") and llm_provider.get("provider_interface"): + model_provider["model"] = model_id + model_provider["provider_interface"] = provider + model_provider_name_set.add(model_provider.get("name")) + if model_provider.get("provider") and model_provider.get( + "provider_interface" + ): raise Exception( "Please provide either provider or provider_interface, not both" ) - if llm_provider.get("provider"): - provider = llm_provider["provider"] - llm_provider["provider_interface"] = provider - del llm_provider["provider"] - updated_llm_providers.append(llm_provider) + if model_provider.get("provider"): + provider = model_provider["provider"] + model_provider["provider_interface"] = provider + del model_provider["provider"] + updated_model_providers.append(model_provider) - if llm_provider.get("base_url", None): - base_url = llm_provider["base_url"] + if model_provider.get("base_url", None): + base_url = model_provider["base_url"] urlparse_result = urlparse(base_url) url_path = urlparse_result.path if url_path and url_path != "/": @@ -240,22 +254,30 @@ def validate_and_render_schema(): else: port = 443 endpoint = urlparse_result.hostname - llm_provider["endpoint"] = endpoint - llm_provider["port"] = port - llm_provider["protocol"] = protocol - llm_provider["cluster_name"] = ( + model_provider["endpoint"] = endpoint + model_provider["port"] = port + model_provider["protocol"] = protocol + model_provider["cluster_name"] = ( provider + "_" + endpoint ) # make name unique by appending endpoint - llms_with_endpoint.append(llm_provider) + llms_with_endpoint.append(model_provider) if len(model_usage_name_keys) > 0: - routing_llm_provider = config_yaml.get("routing", {}).get("llm_provider", None) - if routing_llm_provider and routing_llm_provider not in llm_provider_name_set: + routing_model_provider = config_yaml.get("routing", {}).get( + "model_provider", None + ) + if ( + routing_model_provider + and routing_model_provider not in model_provider_name_set + ): raise Exception( - f"Routing llm_provider {routing_llm_provider} is not defined in llm_providers" + f"Routing model_provider {routing_model_provider} is not defined in model_providers" ) - if routing_llm_provider is None and "arch-router" not in llm_provider_name_set: - updated_llm_providers.append( + if ( + routing_model_provider is None + and "arch-router" not in model_provider_name_set + ): + updated_model_providers.append( { "name": "arch-router", "provider_interface": "arch", @@ -263,19 +285,19 @@ def validate_and_render_schema(): } ) - updated_llm_providers = [] + updated_model_providers = [] for listener in listeners: print("Processing listener: ", listener) - llm_providers = listener.get("llm_providers", None) - if llm_providers is not None and llm_providers != []: + model_providers = listener.get("model_providers", None) + if model_providers is not None and model_providers != []: print("processing egress traffic listener") - print("updated_llm_providers: ", updated_llm_providers) - if updated_llm_providers is not None and updated_llm_providers != []: + print("updated_model_providers: ", updated_model_providers) + if updated_model_providers is not None and updated_model_providers != []: raise Exception( - "Please provide llm_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with llm_providers" + "Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers" ) - updated_llm_providers = deepcopy(llm_providers) - config_yaml["llm_providers"] = updated_llm_providers + updated_model_providers = deepcopy(model_providers) + config_yaml["model_providers"] = updated_model_providers # Validate model aliases if present if "model_aliases" in config_yaml: @@ -317,7 +339,7 @@ def validate_and_render_schema(): "arch_config": arch_config_string, "arch_llm_config": arch_llm_config_string, "arch_clusters": inferred_clusters, - "arch_llm_providers": updated_llm_providers, + "arch_model_providers": updated_model_providers, "arch_tracing": arch_tracing, "local_llms": llms_with_endpoint, "agent_orchestrator": agent_orchestrator, diff --git a/arch/tools/cli/core.py b/arch/tools/cli/core.py index 7bfa9c70..6cd028e7 100644 --- a/arch/tools/cli/core.py +++ b/arch/tools/cli/core.py @@ -5,7 +5,7 @@ import time import sys import yaml -from cli.utils import convert_legacy_llm_providers, getLogger +from cli.utils import convert_legacy_listeners, getLogger from cli.consts import ( ARCHGW_DOCKER_IMAGE, ARCHGW_DOCKER_NAME, @@ -37,7 +37,7 @@ def _get_gateway_ports(arch_config_file: str) -> list[int]: print("arch config dict json string: ", json.dumps(arch_config_dict)) - listeners, _, _ = convert_legacy_llm_providers( + listeners, _, _ = convert_legacy_listeners( arch_config_dict.get("listeners"), arch_config_dict.get("llm_providers") ) diff --git a/arch/tools/cli/utils.py b/arch/tools/cli/utils.py index 0d34f9be..d84322a3 100644 --- a/arch/tools/cli/utils.py +++ b/arch/tools/cli/utils.py @@ -37,20 +37,22 @@ def has_ingress_listener(arch_config_file): return False -def convert_legacy_llm_providers( - listeners: dict | list, llm_providers: list | None +def convert_legacy_listeners( + listeners: dict | list, model_providers: list | None ) -> tuple[list, dict | None, dict | None]: llm_gateway_listener = { "name": "egress_traffic", + "type": "model_listener", "port": 12000, "address": "0.0.0.0", "timeout": "30s", - "llm_providers": [], + "model_providers": model_providers or [], "protocol": "openai", } prompt_gateway_listener = { "name": "ingress_traffic", + "type": "prompt_listener", "port": 10000, "address": "0.0.0.0", "timeout": "30s", @@ -74,10 +76,10 @@ def convert_legacy_llm_providers( llm_gateway_listener["timeout"] = egress_traffic.get( "timeout", llm_gateway_listener["timeout"] ) - if llm_providers is None or llm_providers == []: - raise ValueError("llm_providers cannot be empty when using legacy format") + if model_providers is None or model_providers == []: + raise ValueError("model_providers cannot be empty when using legacy format") - llm_gateway_listener["llm_providers"] = llm_providers + llm_gateway_listener["model_providers"] = model_providers updated_listeners.append(llm_gateway_listener) if ingress_traffic and ingress_traffic != {}: @@ -94,15 +96,16 @@ def convert_legacy_llm_providers( return updated_listeners, llm_gateway_listener, prompt_gateway_listener - llm_provider_set = False + model_provider_set = False for listener in listeners: - if listener.get("llm_providers") is not None: - if llm_provider_set: + if listener.get("type") == "model_listener": + if model_provider_set: raise ValueError( - "Currently only one listener can have llm_providers set" + "Currently only one listener can have model_providers set" ) + listener["model_providers"] = model_providers or [] + model_provider_set = True llm_gateway_listener = listener - llm_provider_set = True return listeners, llm_gateway_listener, prompt_gateway_listener @@ -113,7 +116,7 @@ def get_llm_provider_access_keys(arch_config_file): arch_config_yaml = yaml.safe_load(arch_config) access_key_list = [] - listeners, _, _ = convert_legacy_llm_providers( + listeners, _, _ = convert_legacy_listeners( arch_config_yaml.get("listeners"), arch_config_yaml.get("llm_providers") ) diff --git a/arch/tools/test/test_config_generator.py b/arch/tools/test/test_config_generator.py index 43b7bd45..0e9d6d3b 100644 --- a/arch/tools/test/test_config_generator.py +++ b/arch/tools/test/test_config_generator.py @@ -350,7 +350,7 @@ def test_validate_and_render_schema_tests(monkeypatch, arch_config_test_case): def test_convert_legacy_llm_providers(): - from cli.utils import convert_legacy_llm_providers + from cli.utils import convert_legacy_listeners listeners = { "ingress_traffic": { @@ -373,7 +373,7 @@ def test_convert_legacy_llm_providers(): } ] - updated_providers, llm_gateway, prompt_gateway = convert_legacy_llm_providers( + updated_providers, llm_gateway, prompt_gateway = convert_legacy_listeners( listeners, llm_providers ) assert isinstance(updated_providers, list) @@ -425,7 +425,7 @@ def test_convert_legacy_llm_providers(): def test_convert_legacy_llm_providers_no_prompt_gateway(): - from cli.utils import convert_legacy_llm_providers + from cli.utils import convert_legacy_listeners listeners = { "egress_traffic": { @@ -442,7 +442,7 @@ def test_convert_legacy_llm_providers_no_prompt_gateway(): } ] - updated_providers, llm_gateway, prompt_gateway = convert_legacy_llm_providers( + updated_providers, llm_gateway, prompt_gateway = convert_legacy_listeners( listeners, llm_providers ) assert isinstance(updated_providers, list) diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 897d9eb9..19b5004d 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -62,7 +62,7 @@ async fn main() -> Result<(), Box> { let arch_config = Arc::new(config); - let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone())); + let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone())); let agents_list = Arc::new(RwLock::new(arch_config.agents.clone())); let listeners = Arc::new(RwLock::new(arch_config.listeners.clone())); @@ -87,11 +87,11 @@ async fn main() -> Result<(), Box> { let routing_llm_provider = arch_config .routing .as_ref() - .and_then(|r| r.llm_provider.clone()) + .and_then(|r| r.model_provider.clone()) .unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string()); let router_service: Arc = Arc::new(RouterService::new( - arch_config.llm_providers.clone(), + arch_config.model_providers.clone(), llm_provider_url.clone() + CHAT_COMPLETIONS_PATH, routing_model_name, routing_llm_provider, diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 4ba55dbb..1f8a8e24 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -9,7 +9,7 @@ use crate::api::open_ai::{ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Routing { - pub llm_provider: Option, + pub model_provider: Option, pub model: Option, } @@ -46,7 +46,7 @@ pub struct Listener { pub struct Configuration { pub version: String, pub endpoints: Option>, - pub llm_providers: Vec, + pub model_providers: Vec, pub model_aliases: Option>, pub overrides: Option, pub system_prompt: Option, diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/llm_gateway/src/filter_context.rs index 258a1a1c..2b8e1a95 100644 --- a/crates/llm_gateway/src/filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -74,7 +74,7 @@ impl RootContext for FilterContext { ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default())); self.overrides = Rc::new(config.overrides); - match config.llm_providers.try_into() { + match config.model_providers.try_into() { Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)), Err(err) => panic!("{err}"), } diff --git a/demos/use_cases/llm_routing/arch_config.yaml b/demos/use_cases/llm_routing/arch_config.yaml index 176f53e9..b617958f 100644 --- a/demos/use_cases/llm_routing/arch_config.yaml +++ b/demos/use_cases/llm_routing/arch_config.yaml @@ -1,13 +1,12 @@ -version: v0.1.0 +version: v0.3.0 listeners: - egress_traffic: + - type: model_listener + name: model_listener_1 address: 0.0.0.0 port: 12000 - message_format: openai - timeout: 30s -llm_providers: +model_providers: - access_key: $OPENAI_API_KEY model: openai/gpt-4o-mini diff --git a/docs/source/resources/includes/arch_config_full_reference_rendered.yaml b/docs/source/resources/includes/arch_config_full_reference_rendered.yaml index 7470c56c..8bad7a1b 100644 --- a/docs/source/resources/includes/arch_config_full_reference_rendered.yaml +++ b/docs/source/resources/includes/arch_config_full_reference_rendered.yaml @@ -11,7 +11,7 @@ endpoints: port: 8001 listeners: - address: 0.0.0.0 - llm_providers: + model_providers: - access_key: $OPENAI_API_KEY default: true model: gpt-4o @@ -38,7 +38,12 @@ listeners: port: 10000 protocol: openai timeout: 5s -llm_providers: +model_aliases: + arch.summarize.v1: + target: gpt-4o + arch.v1: + target: mistral-8x7b +model_providers: - access_key: $OPENAI_API_KEY default: true model: gpt-4o @@ -56,11 +61,6 @@ llm_providers: port: 80 protocol: http provider_interface: mistral -model_aliases: - arch.summarize.v1: - target: gpt-4o - arch.v1: - target: mistral-8x7b overrides: prompt_target_intent_matching_threshold: 0.6 prompt_guards: