diff --git a/.github/workflows/e2e_archgw.yml b/.github/workflows/e2e_archgw.yml
index 3d89854c..dda95493 100644
--- a/.github/workflows/e2e_archgw.yml
+++ b/.github/workflows/e2e_archgw.yml
@@ -38,6 +38,7 @@ jobs:
           MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
           GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+          AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
         run: |
           docker compose up | tee &> archgw.logs &
diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml
index df715e00..3319f145 100644
--- a/.github/workflows/e2e_tests.yml
+++ b/.github/workflows/e2e_tests.yml
@@ -31,6 +31,7 @@ jobs:
           MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
           GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+          AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
         run: |
           python -mvenv venv
           source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh
diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml
index 5c2fd420..5ee4c899 100644
--- a/arch/envoy.template.yaml
+++ b/arch/envoy.template.yaml
@@ -127,7 +127,7 @@ static_resources:
             {% for provider in arch_llm_providers %}
             # if endpoint is set then use custom cluster for upstream llm
             {% if provider.endpoint %}
-            {% set llm_cluster_name = provider.name %}
+            {% set llm_cluster_name = provider.cluster_name %}
             {% else %}
             {% set llm_cluster_name = provider.provider_interface %}
             {% endif %}
@@ -421,7 +421,7 @@ static_resources:
             {% for provider in arch_llm_providers %}
             # if endpoint is set then use custom cluster for upstream llm
             {% if provider.endpoint %}
-            {% set llm_cluster_name = provider.name %}
+            {% set llm_cluster_name = provider.cluster_name %}
             {% else %}
             {% set llm_cluster_name = provider.provider_interface %}
             {% endif %}
@@ -576,6 +576,56 @@ static_resources:
            tls_minimum_protocol_version: TLSv1_2
            tls_maximum_protocol_version: TLSv1_3
 
+  - name: xai
+    connect_timeout: 0.5s
+    type: LOGICAL_DNS
+    dns_lookup_family: V4_ONLY
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: xai
+      endpoints:
+        - lb_endpoints:
+            - endpoint:
+                address:
+                  socket_address:
+                    address: api.x.ai
+                    port_value: 443
+                hostname: "api.x.ai"
+    transport_socket:
+      name: envoy.transport_sockets.tls
+      typed_config:
+        "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
+        sni: api.x.ai
+        common_tls_context:
+          tls_params:
+            tls_minimum_protocol_version: TLSv1_2
+            tls_maximum_protocol_version: TLSv1_3
+
+  - name: together_ai
+    connect_timeout: 0.5s
+    type: LOGICAL_DNS
+    dns_lookup_family: V4_ONLY
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: together_ai
+      endpoints:
+        - lb_endpoints:
+            - endpoint:
+                address:
+                  socket_address:
+                    address: api.together.xyz
+                    port_value: 443
+                hostname: "api.together.xyz"
+    transport_socket:
+      name: envoy.transport_sockets.tls
+      typed_config:
+        "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
+        sni: api.together.xyz
+        common_tls_context:
+          tls_params:
+            tls_minimum_protocol_version: TLSv1_2
+            tls_maximum_protocol_version: TLSv1_3
+
   - name: gemini
     connect_timeout: 0.5s
     type: LOGICAL_DNS
@@ -742,13 +792,13 @@ static_resources:
     {% endfor %}
 
     {% for local_llm_provider in local_llms %}
-  - name: {{ local_llm_provider.name }}
+  - name: {{ local_llm_provider.cluster_name }}
     connect_timeout: 0.5s
     type: LOGICAL_DNS
    dns_lookup_family: V4_ONLY
    lb_policy: ROUND_ROBIN
    load_assignment:
-      cluster_name: {{ local_llm_provider.name }}
+      cluster_name: {{ local_llm_provider.cluster_name }}
      endpoints:
        - lb_endpoints:
            - endpoint:
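The template changes above mean that providers with a custom `endpoint` now get an upstream cluster keyed by the derived `provider.cluster_name` instead of `provider.name`. A minimal sketch of how that Jinja2 conditional resolves, using hypothetical provider dicts (values borrowed from the rendered reference config further down):

```python
# Sketch of the llm_cluster_name resolution from arch/envoy.template.yaml.
# The provider dicts are hypothetical stand-ins for arch_llm_providers entries.
from jinja2 import Template

snippet = Template(
    "{% if provider.endpoint %}{{ provider.cluster_name }}"
    "{% else %}{{ provider.provider_interface }}{% endif %}"
)

# Custom endpoint: routed to the unique, endpoint-derived cluster.
print(snippet.render(provider={
    "endpoint": "mistral_local",
    "cluster_name": "mistral_mistral_local",
    "provider_interface": "mistral",
}))  # mistral_mistral_local

# Hosted provider (no endpoint): routed to the shared provider cluster,
# e.g. the new static xai / together_ai clusters defined above.
print(snippet.render(provider={"endpoint": None, "provider_interface": "xai"}))  # xai
```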
diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py
index 1563dd4a..01a85095 100644
--- a/arch/tools/cli/config_generator.py
+++ b/arch/tools/cli/config_generator.py
@@ -14,6 +14,9 @@ SUPPORTED_PROVIDERS = [
     "openai",
     "gemini",
     "anthropic",
+    "together_ai",
+    "azure_openai",
+    "xai",
 ]
@@ -92,15 +95,12 @@ def validate_and_render_schema():
     arch_tracing = config_yaml.get("tracing", {})
     llms_with_endpoint = []
 
-    updated_llm_providers = []
+    llm_provider_name_set = set()
 
-    llms_with_usage = []
     model_name_keys = set()
     model_usage_name_keys = set()
 
     for llm_provider in config_yaml["llm_providers"]:
-        if llm_provider.get("usage", None):
-            llms_with_usage.append(llm_provider["name"])
         if llm_provider.get("name") in llm_provider_name_set:
             raise Exception(
                 f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
             )
@@ -111,16 +111,25 @@ def validate_and_render_schema():
             raise Exception(
                 f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
             )
+        model_name_keys.add(model_name)
 
         if llm_provider.get("name") is None:
             llm_provider["name"] = model_name
+        llm_provider_name_set.add(llm_provider.get("name"))
+
         model_name_tokens = model_name.split("/")
         if len(model_name_tokens) < 2:
             raise Exception(
                 f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_name>."
             )
         provider = model_name_tokens[0]
+        # Validate azure_openai provider requires base_url
+        if provider == "azure_openai" and llm_provider.get("base_url") is None:
+            raise Exception(
+                f"Provider 'azure_openai' requires 'base_url' to be set for model {model_name}"
+            )
+
         model_id = "/".join(model_name_tokens[1:])
         if provider not in SUPPORTED_PROVIDERS:
             if (
@@ -151,16 +160,6 @@ def validate_and_render_schema():
 
         llm_provider["model"] = model_id
         llm_provider["provider_interface"] = provider
 
-        llm_provider_name_set.add(llm_provider.get("name"))
-        provider = None
-        if llm_provider.get("provider") and llm_provider.get("provider_interface"):
-            raise Exception(
-                "Please provide either provider or provider_interface, not both"
-            )
-        if llm_provider.get("provider"):
-            provider = llm_provider["provider"]
-            llm_provider["provider_interface"] = provider
-            del llm_provider["provider"]
         updated_llm_providers.append(llm_provider)
 
         if llm_provider.get("base_url", None):
@@ -189,6 +188,9 @@ def validate_and_render_schema():
             llm_provider["endpoint"] = endpoint
             llm_provider["port"] = port
             llm_provider["protocol"] = protocol
+            llm_provider["cluster_name"] = (
+                provider + "_" + endpoint
+            )  # make name unique by appending endpoint
             llms_with_endpoint.append(llm_provider)
 
     if len(model_usage_name_keys) > 0:
diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs
index 81c2db4f..034e9148 100644
--- a/crates/common/src/configuration.rs
+++ b/crates/common/src/configuration.rs
@@ -167,6 +167,12 @@ pub enum LlmProviderType {
     OpenAI,
     #[serde(rename = "gemini")]
     Gemini,
+    #[serde(rename = "xai")]
+    XAI,
+    #[serde(rename = "together_ai")]
+    TogetherAI,
+    #[serde(rename = "azure_openai")]
+    AzureOpenAI,
 }
 
 impl Display for LlmProviderType {
@@ -179,6 +185,9 @@ impl Display for LlmProviderType {
             LlmProviderType::Gemini => write!(f, "gemini"),
             LlmProviderType::Mistral => write!(f, "mistral"),
             LlmProviderType::OpenAI => write!(f, "openai"),
+            LlmProviderType::XAI => write!(f, "xai"),
+            LlmProviderType::TogetherAI => write!(f, "together_ai"),
+            LlmProviderType::AzureOpenAI => write!(f, "azure_openai"),
         }
     }
 }
@@ -217,6 +226,7 @@ pub struct LlmProvider {
     pub rate_limits: Option,
     pub usage: Option,
     pub routing_preferences: Option>,
+    pub cluster_name: Option<String>,
 }
 
 pub trait IntoModels {
@@ -256,6 +266,7 @@ impl Default for LlmProvider {
             rate_limits: None,
             usage: None,
             routing_preferences: None,
+            cluster_name: None,
         }
     }
 }
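For reference, the new validation and naming rules in `validate_and_render_schema()` boil down to the sketch below. It is simplified and self-contained; the sample config is hypothetical, and parsing the host out of `base_url` via `urlparse` is an assumption about code outside the hunk:

```python
# Simplified restatement of the new config_generator.py rules; urlparse
# stands in for however the real code extracts endpoint/port from base_url.
from urllib.parse import urlparse


def normalize_provider(llm_provider: dict) -> dict:
    model_name = llm_provider["model"]
    tokens = model_name.split("/")
    if len(tokens) < 2:
        raise Exception(
            f"Invalid model name {model_name}. Please provide model name "
            "in the format <provider>/<model_name>."
        )
    provider = tokens[0]

    # azure_openai cannot fall back to a well-known host: deployments live
    # under the customer's own Azure endpoint, so base_url is mandatory.
    if provider == "azure_openai" and llm_provider.get("base_url") is None:
        raise Exception(
            f"Provider 'azure_openai' requires 'base_url' to be set for model {model_name}"
        )

    if llm_provider.get("base_url"):
        endpoint = urlparse(llm_provider["base_url"]).hostname
        # Appending the endpoint keeps cluster names unique even when two
        # providers share the same provider interface.
        llm_provider["cluster_name"] = provider + "_" + endpoint
    return llm_provider


print(normalize_provider(
    {"model": "mistral/mistral-7b-instruct", "base_url": "http://mistral_local"}
)["cluster_name"])  # mistral_mistral_local, as in the rendered reference config
```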
"azure_openai"), } } } @@ -217,6 +226,7 @@ pub struct LlmProvider { pub rate_limits: Option, pub usage: Option, pub routing_preferences: Option>, + pub cluster_name: Option, } pub trait IntoModels { @@ -256,6 +266,7 @@ impl Default for LlmProvider { rate_limits: None, usage: None, routing_preferences: None, + cluster_name: None, } } } diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index 5af51fe0..2b0f1ca8 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -62,7 +62,7 @@ impl SupportedAPIs { } } - pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str) -> String { + pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str) -> String { let default_endpoint = "/v1/chat/completions".to_string(); match self { SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => { @@ -80,6 +80,13 @@ impl SupportedAPIs { default_endpoint } } + ProviderId::AzureOpenAI => { + if request_path.starts_with("/v1/") { + format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id) + } else { + default_endpoint + } + } ProviderId::Gemini => { if request_path.starts_with("/v1/") { "/v1beta/openai/chat/completions".to_string() diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 26933adc..13ef4c6e 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -13,6 +13,9 @@ pub enum ProviderId { Anthropic, GitHub, Arch, + AzureOpenAI, + XAI, + TogetherAI, } impl From<&str> for ProviderId { @@ -26,6 +29,9 @@ impl From<&str> for ProviderId { "anthropic" => ProviderId::Anthropic, "github" => ProviderId::GitHub, "arch" => ProviderId::Arch, + "azure_openai" => ProviderId::AzureOpenAI, + "xai" => ProviderId::XAI, + "together_ai" => ProviderId::TogetherAI, _ => panic!("Unknown provider: {}", value), } } @@ -40,8 +46,29 @@ impl ProviderId { (ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), // OpenAI-compatible providers only support OpenAI chat completions - (ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), - (ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + (ProviderId::OpenAI + | ProviderId::Groq + | ProviderId::Mistral + | ProviderId::Deepseek + | ProviderId::Arch + | ProviderId::Gemini + | ProviderId::GitHub + | ProviderId::AzureOpenAI + | ProviderId::XAI + | ProviderId::TogetherAI, + SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + + (ProviderId::OpenAI + | ProviderId::Groq + | ProviderId::Mistral + | ProviderId::Deepseek + | ProviderId::Arch + | ProviderId::Gemini + | ProviderId::GitHub + | ProviderId::AzureOpenAI + | ProviderId::XAI + | ProviderId::TogetherAI, + SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), } } } @@ -57,6 +84,9 @@ impl Display for ProviderId { ProviderId::Anthropic => 
write!(f, "Anthropic"), ProviderId::GitHub => write!(f, "GitHub"), ProviderId::Arch => write!(f, "Arch"), + ProviderId::AzureOpenAI => write!(f, "azure_openai"), + ProviderId::XAI => write!(f, "xai"), + ProviderId::TogetherAI => write!(f, "together_ai"), } } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index da86296d..b1651e0c 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -98,8 +98,14 @@ impl StreamContext { fn update_upstream_path(&mut self, request_path: &str) { let hermes_provider_id = self.llm_provider().to_provider_id(); if let Some(api) = &self.client_api { - let target_endpoint = - api.target_endpoint_for_provider(&hermes_provider_id, request_path); + let target_endpoint = api.target_endpoint_for_provider( + &hermes_provider_id, + request_path, + self.llm_provider() + .model + .as_ref() + .unwrap_or(&"".to_string()), + ); if target_endpoint != request_path { self.set_http_request_header(":path", Some(&target_endpoint)); } @@ -622,7 +628,12 @@ impl HttpContext for StreamContext { if self.llm_provider().endpoint.is_some() { self.add_http_request_header( ARCH_ROUTING_HEADER, - &self.llm_provider().name.to_string(), + &self + .llm_provider() + .cluster_name + .as_ref() + .unwrap() + .to_string(), ); } else { self.add_http_request_header( diff --git a/demos/use_cases/llm_routing/arch_config.yaml b/demos/use_cases/llm_routing/arch_config.yaml index f90643ff..176f53e9 100644 --- a/demos/use_cases/llm_routing/arch_config.yaml +++ b/demos/use_cases/llm_routing/arch_config.yaml @@ -37,6 +37,12 @@ llm_providers: - access_key: $GEMINI_API_KEY model: gemini/gemini-1.5-pro-latest + - model: xai/grok-4-latest + access_key: $GROK_API_KEY + + - model: together_ai/openai/gpt-oss-20b + access_key: $TOGETHER_API_KEY + - model: custom/test-model base_url: http://host.docker.internal:11223 provider_interface: openai diff --git a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml index 6a0fe25f..d42583e4 100644 --- a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml +++ b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml @@ -8,6 +8,7 @@ listeners: timeout: 30s llm_providers: + # OpenAI Models - model: openai/gpt-4o-mini access_key: $OPENAI_API_KEY @@ -25,6 +26,12 @@ llm_providers: - model: anthropic/claude-3-haiku-20240307 access_key: $ANTHROPIC_API_KEY + # Azure OpenAI Models + + - model: azure_openai/gpt-5-mini + access_key: $AZURE_API_KEY + base_url: https://katanemo.openai.azure.com + # Model aliases - friendly names that map to actual provider names model_aliases: diff --git a/docs/source/resources/includes/arch_config_full_reference.yaml b/docs/source/resources/includes/arch_config_full_reference.yaml index 808baff1..c9d5e4ff 100644 --- a/docs/source/resources/includes/arch_config_full_reference.yaml +++ b/docs/source/resources/includes/arch_config_full_reference.yaml @@ -41,6 +41,16 @@ llm_providers: - model: mistral/mistral-7b-instruct base_url: http://mistral_local +# Model aliases - friendly names that map to actual provider names +model_aliases: + # Alias for summarization tasks -> fast/cheap model + arch.summarize.v1: + target: gpt-4o + + # Alias for general purpose tasks -> latest model + arch.v1: + target: mistral-8x7b + # provides a way to override default settings for the arch system overrides: # By default Arch uses an NLI + embedding approach to match an incoming 
diff --git a/docs/source/resources/includes/arch_config_full_reference_rendered.yaml b/docs/source/resources/includes/arch_config_full_reference_rendered.yaml
index 503f6a80..4c791e82 100644
--- a/docs/source/resources/includes/arch_config_full_reference_rendered.yaml
+++ b/docs/source/resources/includes/arch_config_full_reference_rendered.yaml
@@ -31,12 +31,18 @@ llm_providers:
     name: mistral/mistral-8x7b
     provider_interface: mistral
   - base_url: http://mistral_local
+    cluster_name: mistral_mistral_local
     endpoint: mistral_local
     model: mistral-7b-instruct
     name: mistral/mistral-7b-instruct
     port: 80
     protocol: http
     provider_interface: mistral
+model_aliases:
+  arch.summarize.v1:
+    target: gpt-4o
+  arch.v1:
+    target: mistral-8x7b
 overrides:
   prompt_target_intent_matching_threshold: 0.6
 prompt_guards:
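Since the rendered reference now carries `model_aliases`, clients can also address models by a stable alias rather than the provider-qualified name; presumably the gateway resolves the alias before provider routing (same assumed listener address as above):

```python
# Hypothetical alias call: "arch.summarize.v1" maps to gpt-4o per model_aliases.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:12000/v1", api_key="not-used")
resp = client.chat.completions.create(
    model="arch.summarize.v1",
    messages=[{"role": "user", "content": "Summarize: Arch routes LLM traffic."}],
)
print(resp.model)  # expected to reflect the resolved target model
```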