draft commit to add support for xAI, LambdaAI, TogetherAI, AzureOpenAI

This commit is contained in:
Salman Paracha 2025-09-17 22:47:33 -07:00
parent b56311f458
commit 79ff4bb164
7 changed files with 170 additions and 24 deletions

View file

@ -127,7 +127,7 @@ static_resources:
{% for provider in arch_llm_providers %}
# if endpoint is set then use custom cluster for upstream llm
{% if provider.endpoint %}
{% set llm_cluster_name = provider.name %}
{% set llm_cluster_name = provider.cluster_name %}
{% else %}
{% set llm_cluster_name = provider.provider_interface %}
{% endif %}
@ -421,7 +421,7 @@ static_resources:
{% for provider in arch_llm_providers %}
# if endpoint is set then use custom cluster for upstream llm
{% if provider.endpoint %}
{% set llm_cluster_name = provider.name %}
{% set llm_cluster_name = provider.cluster_name %}
{% else %}
{% set llm_cluster_name = provider.provider_interface %}
{% endif %}
@ -576,6 +576,82 @@ static_resources:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
# Upstream cluster for the xAI API: api.x.ai on port 443, reached over TLS.
- name: xai
connect_timeout: 0.5s
# Resolve the upstream via DNS (IPv4 only) and round-robin across the results.
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
# Mirrors the cluster's own name above.
cluster_name: xai
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.x.ai
port_value: 443
hostname: "api.x.ai"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
# SNI presented during the TLS handshake; same host as the socket address.
sni: api.x.ai
common_tls_context:
tls_params:
# Only negotiate TLS 1.2 through TLS 1.3 with the upstream.
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
# Upstream cluster for the Together AI API: api.together.xyz on port 443, reached over TLS.
- name: together_ai
connect_timeout: 0.5s
# Resolve the upstream via DNS (IPv4 only) and round-robin across the results.
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
# Must name this cluster (together_ai), not the xai cluster it was copied from.
cluster_name: together_ai
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.together.xyz
port_value: 443
hostname: "api.together.xyz"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
# SNI presented during the TLS handshake; same host as the socket address.
sni: api.together.xyz
common_tls_context:
tls_params:
# Only negotiate TLS 1.2 through TLS 1.3 with the upstream.
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
# Upstream cluster for the Lambda AI API: api.lambda.ai on port 443, reached over TLS.
- name: lambda_ai
connect_timeout: 0.5s
# Resolve the upstream via DNS (IPv4 only) and round-robin across the results.
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
# Must name this cluster (lambda_ai), not the xai cluster it was copied from.
cluster_name: lambda_ai
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.lambda.ai
port_value: 443
hostname: "api.lambda.ai"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
# SNI presented during the TLS handshake; same host as the socket address.
sni: api.lambda.ai
common_tls_context:
tls_params:
# Only negotiate TLS 1.2 through TLS 1.3 with the upstream.
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
- name: gemini
connect_timeout: 0.5s
type: LOGICAL_DNS
@ -742,13 +818,13 @@ static_resources:
{% endfor %}
{% for local_llm_provider in local_llms %}
- name: {{ local_llm_provider.name }}
- name: {{ local_llm_provider.cluster_name }}
connect_timeout: 0.5s
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: {{ local_llm_provider.name }}
cluster_name: {{ local_llm_provider.cluster_name }}
endpoints:
- lb_endpoints:
- endpoint:

View file

@ -14,6 +14,10 @@ SUPPORTED_PROVIDERS = [
"openai",
"gemini",
"anthropic",
"together_ai",
"lambda_ai",
"azure_openai",
"xai",
]
@ -92,15 +96,12 @@ def validate_and_render_schema():
arch_tracing = config_yaml.get("tracing", {})
llms_with_endpoint = []
updated_llm_providers = []
llm_provider_name_set = set()
llms_with_usage = []
model_name_keys = set()
model_usage_name_keys = set()
for llm_provider in config_yaml["llm_providers"]:
if llm_provider.get("usage", None):
llms_with_usage.append(llm_provider["name"])
if llm_provider.get("name") in llm_provider_name_set:
raise Exception(
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
@ -111,10 +112,13 @@ def validate_and_render_schema():
raise Exception(
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
)
model_name_keys.add(model_name)
if llm_provider.get("name") is None:
llm_provider["name"] = model_name
llm_provider_name_set.add(llm_provider.get("name"))
model_name_tokens = model_name.split("/")
if len(model_name_tokens) < 2:
raise Exception(
@ -151,16 +155,6 @@ def validate_and_render_schema():
llm_provider["model"] = model_id
llm_provider["provider_interface"] = provider
llm_provider_name_set.add(llm_provider.get("name"))
provider = None
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
raise Exception(
"Please provide either provider or provider_interface, not both"
)
if llm_provider.get("provider"):
provider = llm_provider["provider"]
llm_provider["provider_interface"] = provider
del llm_provider["provider"]
updated_llm_providers.append(llm_provider)
if llm_provider.get("base_url", None):
@ -189,6 +183,9 @@ def validate_and_render_schema():
llm_provider["endpoint"] = endpoint
llm_provider["port"] = port
llm_provider["protocol"] = protocol
llm_provider["cluster_name"] = (
provider + "_" + endpoint
) # make name unique by appending endpoint
llms_with_endpoint.append(llm_provider)
if len(model_usage_name_keys) > 0:

View file

@ -167,6 +167,14 @@ pub enum LlmProviderType {
OpenAI,
#[serde(rename = "gemini")]
Gemini,
#[serde(rename = "xai")]
XAI,
#[serde(rename = "together_ai")]
TogetherAI,
#[serde(rename = "lambda_ai")]
LambdaAI,
#[serde(rename = "azure_openai")]
AzureOpenAI,
}
impl Display for LlmProviderType {
@ -179,6 +187,10 @@ impl Display for LlmProviderType {
LlmProviderType::Gemini => write!(f, "gemini"),
LlmProviderType::Mistral => write!(f, "mistral"),
LlmProviderType::OpenAI => write!(f, "openai"),
LlmProviderType::XAI => write!(f, "xai"),
LlmProviderType::TogetherAI => write!(f, "together_ai"),
LlmProviderType::LambdaAI => write!(f, "lambda_ai"),
LlmProviderType::AzureOpenAI => write!(f, "azure_openai"),
}
}
}
@ -217,6 +229,7 @@ pub struct LlmProvider {
pub rate_limits: Option<LlmRatelimit>,
pub usage: Option<String>,
pub routing_preferences: Option<Vec<RoutingPreference>>,
pub cluster_name: Option<String>,
}
pub trait IntoModels {
@ -256,6 +269,7 @@ impl Default for LlmProvider {
rate_limits: None,
usage: None,
routing_preferences: None,
cluster_name: None,
}
}
}

View file

@ -62,7 +62,7 @@ impl SupportedAPIs {
}
}
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str) -> String {
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str) -> String {
let default_endpoint = "/v1/chat/completions".to_string();
match self {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
@ -80,6 +80,13 @@ impl SupportedAPIs {
default_endpoint
}
}
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
} else {
default_endpoint
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
"/v1beta/openai/chat/completions".to_string()

View file

@ -13,6 +13,10 @@ pub enum ProviderId {
Anthropic,
GitHub,
Arch,
AzureOpenAI,
XAI,
TogetherAI,
LambdaAI,
}
impl From<&str> for ProviderId {
@ -26,6 +30,10 @@ impl From<&str> for ProviderId {
"anthropic" => ProviderId::Anthropic,
"github" => ProviderId::GitHub,
"arch" => ProviderId::Arch,
"azure_openai" => ProviderId::AzureOpenAI,
"xai" => ProviderId::XAI,
"together_ai" => ProviderId::TogetherAI,
"lambda_ai" => ProviderId::LambdaAI,
_ => panic!("Unknown provider: {}", value),
}
}
@ -40,8 +48,31 @@ impl ProviderId {
(ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// OpenAI-compatible providers only support OpenAI chat completions
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::LambdaAI,
SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::LambdaAI,
SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
}
}
}
@ -57,6 +88,10 @@ impl Display for ProviderId {
ProviderId::Anthropic => write!(f, "Anthropic"),
ProviderId::GitHub => write!(f, "GitHub"),
ProviderId::Arch => write!(f, "Arch"),
ProviderId::AzureOpenAI => write!(f, "azure_openai"),
ProviderId::XAI => write!(f, "xai"),
ProviderId::TogetherAI => write!(f, "together_ai"),
ProviderId::LambdaAI => write!(f, "lambda_ai"),
}
}
}

View file

@ -98,8 +98,14 @@ impl StreamContext {
fn update_upstream_path(&mut self, request_path: &str) {
let hermes_provider_id = self.llm_provider().to_provider_id();
if let Some(api) = &self.client_api {
let target_endpoint =
api.target_endpoint_for_provider(&hermes_provider_id, request_path);
let target_endpoint = api.target_endpoint_for_provider(
&hermes_provider_id,
request_path,
self.llm_provider()
.model
.as_ref()
.unwrap_or(&"".to_string()),
);
if target_endpoint != request_path {
self.set_http_request_header(":path", Some(&target_endpoint));
}
@ -622,7 +628,12 @@ impl HttpContext for StreamContext {
if self.llm_provider().endpoint.is_some() {
self.add_http_request_header(
ARCH_ROUTING_HEADER,
&self.llm_provider().name.to_string(),
&self
.llm_provider()
.cluster_name
.as_ref()
.unwrap()
.to_string(),
);
} else {
self.add_http_request_header(

View file

@ -25,6 +25,12 @@ llm_providers:
- model: anthropic/claude-3-haiku-20240307
access_key: $ANTHROPIC_API_KEY
# Azure OpenAI Models
- model: azure_openai/gpt-5-mini
access_key: $AZURE_API_KEY
base_url: https://katanemo.openai.azure.com
# Model aliases - friendly names that map to actual provider names
model_aliases: