diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index 5c2fd420..a00bb19d 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -127,7 +127,7 @@ static_resources: {% for provider in arch_llm_providers %} # if endpoint is set then use custom cluster for upstream llm {% if provider.endpoint %} - {% set llm_cluster_name = provider.name %} + {% set llm_cluster_name = provider.cluster_name %} {% else %} {% set llm_cluster_name = provider.provider_interface %} {% endif %} @@ -421,7 +421,7 @@ static_resources: {% for provider in arch_llm_providers %} # if endpoint is set then use custom cluster for upstream llm {% if provider.endpoint %} - {% set llm_cluster_name = provider.name %} + {% set llm_cluster_name = provider.cluster_name %} {% else %} {% set llm_cluster_name = provider.provider_interface %} {% endif %} @@ -576,6 +576,82 @@ static_resources: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + - name: xai + connect_timeout: 0.5s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: xai + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: api.x.ai + port_value: 443 + hostname: "api.x.ai" + transport_socket: + name: envoy.transport_sockets.tls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + sni: api.x.ai + common_tls_context: + tls_params: + tls_minimum_protocol_version: TLSv1_2 + tls_maximum_protocol_version: TLSv1_3 + + - name: together_ai + connect_timeout: 0.5s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: together_ai + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: api.together.xyz + port_value: 443 + hostname: "api.together.xyz" + transport_socket: + name: envoy.transport_sockets.tls + typed_config: + "@type": 
type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + sni: api.together.xyz + common_tls_context: + tls_params: + tls_minimum_protocol_version: TLSv1_2 + tls_maximum_protocol_version: TLSv1_3 + + - name: lambda_ai + connect_timeout: 0.5s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: lambda_ai + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: api.lambda.ai + port_value: 443 + hostname: "api.lambda.ai" + transport_socket: + name: envoy.transport_sockets.tls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + sni: api.lambda.ai + common_tls_context: + tls_params: + tls_minimum_protocol_version: TLSv1_2 + tls_maximum_protocol_version: TLSv1_3 + + - name: gemini connect_timeout: 0.5s type: LOGICAL_DNS @@ -742,13 +818,13 @@ static_resources: {% endfor %} {% for local_llm_provider in local_llms %} - - name: {{ local_llm_provider.name }} + - name: {{ local_llm_provider.cluster_name }} connect_timeout: 0.5s type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN load_assignment: - cluster_name: {{ local_llm_provider.name }} + cluster_name: {{ local_llm_provider.cluster_name }} endpoints: - lb_endpoints: - endpoint: diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 1563dd4a..d6d746e9 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -14,6 +14,10 @@ SUPPORTED_PROVIDERS = [ "openai", "gemini", "anthropic", + "together_ai", + "lambda_ai", + "azure_openai", + "xai", ] @@ -92,15 +96,12 @@ def validate_and_render_schema(): arch_tracing = config_yaml.get("tracing", {}) llms_with_endpoint = [] - updated_llm_providers = [] + llm_provider_name_set = set() - llms_with_usage = [] model_name_keys = set() model_usage_name_keys = set() for llm_provider in config_yaml["llm_providers"]: - if llm_provider.get("usage", 
None): - llms_with_usage.append(llm_provider["name"]) if llm_provider.get("name") in llm_provider_name_set: raise Exception( f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider" @@ -111,10 +112,13 @@ def validate_and_render_schema(): raise Exception( f"Duplicate model name {model_name}, please provide unique model name for each llm_provider" ) + model_name_keys.add(model_name) if llm_provider.get("name") is None: llm_provider["name"] = model_name + llm_provider_name_set.add(llm_provider.get("name")) + model_name_tokens = model_name.split("/") if len(model_name_tokens) < 2: raise Exception( @@ -151,16 +155,6 @@ def validate_and_render_schema(): llm_provider["model"] = model_id llm_provider["provider_interface"] = provider - llm_provider_name_set.add(llm_provider.get("name")) - provider = None - if llm_provider.get("provider") and llm_provider.get("provider_interface"): - raise Exception( - "Please provide either provider or provider_interface, not both" - ) - if llm_provider.get("provider"): - provider = llm_provider["provider"] - llm_provider["provider_interface"] = provider - del llm_provider["provider"] updated_llm_providers.append(llm_provider) if llm_provider.get("base_url", None): @@ -189,6 +183,9 @@ def validate_and_render_schema(): llm_provider["endpoint"] = endpoint llm_provider["port"] = port llm_provider["protocol"] = protocol + llm_provider["cluster_name"] = ( + provider + "_" + endpoint + ) # make name unique by appending endpoint llms_with_endpoint.append(llm_provider) if len(model_usage_name_keys) > 0: diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 81c2db4f..efd99704 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -167,6 +167,14 @@ pub enum LlmProviderType { OpenAI, #[serde(rename = "gemini")] Gemini, + #[serde(rename = "xai")] + XAI, + #[serde(rename = "together_ai")] + TogetherAI, + #[serde(rename = 
"lambda_ai")] + LambdaAI, + #[serde(rename = "azure_openai")] + AzureOpenAI, } impl Display for LlmProviderType { @@ -179,6 +187,10 @@ impl Display for LlmProviderType { LlmProviderType::Gemini => write!(f, "gemini"), LlmProviderType::Mistral => write!(f, "mistral"), LlmProviderType::OpenAI => write!(f, "openai"), + LlmProviderType::XAI => write!(f, "xai"), + LlmProviderType::TogetherAI => write!(f, "together_ai"), + LlmProviderType::LambdaAI => write!(f, "lambda_ai"), + LlmProviderType::AzureOpenAI => write!(f, "azure_openai"), } } } @@ -217,6 +229,7 @@ pub struct LlmProvider { pub rate_limits: Option, pub usage: Option, pub routing_preferences: Option>, + pub cluster_name: Option, } pub trait IntoModels { @@ -256,6 +269,7 @@ impl Default for LlmProvider { rate_limits: None, usage: None, routing_preferences: None, + cluster_name: None, } } } diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index 5af51fe0..2b0f1ca8 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -62,7 +62,7 @@ impl SupportedAPIs { } } - pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str) -> String { + pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str) -> String { let default_endpoint = "/v1/chat/completions".to_string(); match self { SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => { @@ -80,6 +80,13 @@ impl SupportedAPIs { default_endpoint } } + ProviderId::AzureOpenAI => { + if request_path.starts_with("/v1/") { + format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id) + } else { + default_endpoint + } + } ProviderId::Gemini => { if request_path.starts_with("/v1/") { "/v1beta/openai/chat/completions".to_string() diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 26933adc..bc8691f2 100644 --- 
a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -13,6 +13,10 @@ pub enum ProviderId { Anthropic, GitHub, Arch, + AzureOpenAI, + XAI, + TogetherAI, + LambdaAI, } impl From<&str> for ProviderId { @@ -26,6 +30,10 @@ impl From<&str> for ProviderId { "anthropic" => ProviderId::Anthropic, "github" => ProviderId::GitHub, "arch" => ProviderId::Arch, + "azure_openai" => ProviderId::AzureOpenAI, + "xai" => ProviderId::XAI, + "together_ai" => ProviderId::TogetherAI, + "lambda_ai" => ProviderId::LambdaAI, _ => panic!("Unknown provider: {}", value), } } @@ -40,8 +48,31 @@ impl ProviderId { (ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), // OpenAI-compatible providers only support OpenAI chat completions - (ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), - (ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + (ProviderId::OpenAI + | ProviderId::Groq + | ProviderId::Mistral + | ProviderId::Deepseek + | ProviderId::Arch + | ProviderId::Gemini + | ProviderId::GitHub + | ProviderId::AzureOpenAI + | ProviderId::XAI + | ProviderId::TogetherAI + | ProviderId::LambdaAI, + SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + + (ProviderId::OpenAI + | ProviderId::Groq + | ProviderId::Mistral + | ProviderId::Deepseek + | ProviderId::Arch + | ProviderId::Gemini + | ProviderId::GitHub + | ProviderId::AzureOpenAI + | ProviderId::XAI + | ProviderId::TogetherAI + | ProviderId::LambdaAI, + 
SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), } } } @@ -57,6 +88,10 @@ impl Display for ProviderId { ProviderId::Anthropic => write!(f, "Anthropic"), ProviderId::GitHub => write!(f, "GitHub"), ProviderId::Arch => write!(f, "Arch"), + ProviderId::AzureOpenAI => write!(f, "azure_openai"), + ProviderId::XAI => write!(f, "xai"), + ProviderId::TogetherAI => write!(f, "together_ai"), + ProviderId::LambdaAI => write!(f, "lambda_ai"), } } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index da86296d..b1651e0c 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -98,8 +98,14 @@ impl StreamContext { fn update_upstream_path(&mut self, request_path: &str) { let hermes_provider_id = self.llm_provider().to_provider_id(); if let Some(api) = &self.client_api { - let target_endpoint = - api.target_endpoint_for_provider(&hermes_provider_id, request_path); + let target_endpoint = api.target_endpoint_for_provider( + &hermes_provider_id, + request_path, + self.llm_provider() + .model + .as_ref() + .unwrap_or(&"".to_string()), + ); if target_endpoint != request_path { self.set_http_request_header(":path", Some(&target_endpoint)); } @@ -622,7 +628,12 @@ impl HttpContext for StreamContext { if self.llm_provider().endpoint.is_some() { self.add_http_request_header( ARCH_ROUTING_HEADER, - &self.llm_provider().name.to_string(), + &self + .llm_provider() + .cluster_name + .as_ref() + .unwrap() + .to_string(), ); } else { self.add_http_request_header( diff --git a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml index 6a0fe25f..da26bb42 100644 --- a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml +++ b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml @@ -25,6 +25,12 @@ llm_providers: - model: 
anthropic/claude-3-haiku-20240307 access_key: $ANTHROPIC_API_KEY + # Azure OpenAI Models + + - model: azure_openai/gpt-5-mini + access_key: $AZURE_API_KEY + base_url: https://katanemo.openai.azure.com + # Model aliases - friendly names that map to actual provider names model_aliases: