mirror of
https://github.com/katanemo/plano.git
synced 2026-06-11 15:05:14 +02:00
draft commit to add support for xAI, TogetherAI, AzureOpenAI (#570)
* draft commit to add support for xAI, LambdaAI, TogetherAI, AzureOpenAI * fixing failing tests and updating rendered config file * Update arch_config_with_aliases.yaml * adding the AZURE_API_KEY to the GH workflow for e2e * fixing GH secrets * adding validation for azure_openai --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-167.local>
This commit is contained in:
parent
b56311f458
commit
8d0b468345
12 changed files with 166 additions and 24 deletions
1
.github/workflows/e2e_archgw.yml
vendored
1
.github/workflows/e2e_archgw.yml
vendored
|
|
@ -38,6 +38,7 @@ jobs:
|
|||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
run: |
|
||||
docker compose up | tee &> archgw.logs &
|
||||
|
||||
|
|
|
|||
1
.github/workflows/e2e_tests.yml
vendored
1
.github/workflows/e2e_tests.yml
vendored
|
|
@ -31,6 +31,7 @@ jobs:
|
|||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
run: |
|
||||
python -mvenv venv
|
||||
source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ static_resources:
|
|||
{% for provider in arch_llm_providers %}
|
||||
# if endpoint is set then use custom cluster for upstream llm
|
||||
{% if provider.endpoint %}
|
||||
{% set llm_cluster_name = provider.name %}
|
||||
{% set llm_cluster_name = provider.cluster_name %}
|
||||
{% else %}
|
||||
{% set llm_cluster_name = provider.provider_interface %}
|
||||
{% endif %}
|
||||
|
|
@ -421,7 +421,7 @@ static_resources:
|
|||
{% for provider in arch_llm_providers %}
|
||||
# if endpoint is set then use custom cluster for upstream llm
|
||||
{% if provider.endpoint %}
|
||||
{% set llm_cluster_name = provider.name %}
|
||||
{% set llm_cluster_name = provider.cluster_name %}
|
||||
{% else %}
|
||||
{% set llm_cluster_name = provider.provider_interface %}
|
||||
{% endif %}
|
||||
|
|
@ -576,6 +576,56 @@ static_resources:
|
|||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
|
||||
- name: xai
|
||||
connect_timeout: 0.5s
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: xai
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: api.x.ai
|
||||
port_value: 443
|
||||
hostname: "api.x.ai"
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
sni: api.x.ai
|
||||
common_tls_context:
|
||||
tls_params:
|
||||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
|
||||
# Upstream cluster for the Together AI API: resolves api.together.xyz via
# DNS and connects on port 443 over TLS (1.2 minimum, 1.3 maximum).
- name: together_ai
  connect_timeout: 0.5s
  type: LOGICAL_DNS
  dns_lookup_family: V4_ONLY
  lb_policy: ROUND_ROBIN
  load_assignment:
    # Kept in sync with the cluster's own name; this was previously "xai",
    # a copy-paste leftover from the preceding cluster entry.
    cluster_name: together_ai
    endpoints:
      - lb_endpoints:
          - endpoint:
              address:
                socket_address:
                  address: api.together.xyz
                  port_value: 443
              hostname: "api.together.xyz"
  transport_socket:
    name: envoy.transport_sockets.tls
    typed_config:
      "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
      # SNI must name the upstream host so the TLS handshake selects the right certificate.
      sni: api.together.xyz
      common_tls_context:
        tls_params:
          tls_minimum_protocol_version: TLSv1_2
          tls_maximum_protocol_version: TLSv1_3
|
||||
|
||||
- name: gemini
|
||||
connect_timeout: 0.5s
|
||||
type: LOGICAL_DNS
|
||||
|
|
@ -742,13 +792,13 @@ static_resources:
|
|||
{% endfor %}
|
||||
|
||||
{% for local_llm_provider in local_llms %}
|
||||
- name: {{ local_llm_provider.name }}
|
||||
- name: {{ local_llm_provider.cluster_name }}
|
||||
connect_timeout: 0.5s
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: {{ local_llm_provider.name }}
|
||||
cluster_name: {{ local_llm_provider.cluster_name }}
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ SUPPORTED_PROVIDERS = [
|
|||
"openai",
|
||||
"gemini",
|
||||
"anthropic",
|
||||
"together_ai",
|
||||
"azure_openai",
|
||||
"xai",
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -92,15 +95,12 @@ def validate_and_render_schema():
|
|||
arch_tracing = config_yaml.get("tracing", {})
|
||||
|
||||
llms_with_endpoint = []
|
||||
|
||||
updated_llm_providers = []
|
||||
|
||||
llm_provider_name_set = set()
|
||||
llms_with_usage = []
|
||||
model_name_keys = set()
|
||||
model_usage_name_keys = set()
|
||||
for llm_provider in config_yaml["llm_providers"]:
|
||||
if llm_provider.get("usage", None):
|
||||
llms_with_usage.append(llm_provider["name"])
|
||||
if llm_provider.get("name") in llm_provider_name_set:
|
||||
raise Exception(
|
||||
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
|
||||
|
|
@ -111,16 +111,25 @@ def validate_and_render_schema():
|
|||
raise Exception(
|
||||
f"Duplicate model name {model_name}, please provide unique model name for each llm_provider"
|
||||
)
|
||||
|
||||
model_name_keys.add(model_name)
|
||||
if llm_provider.get("name") is None:
|
||||
llm_provider["name"] = model_name
|
||||
|
||||
llm_provider_name_set.add(llm_provider.get("name"))
|
||||
|
||||
model_name_tokens = model_name.split("/")
|
||||
if len(model_name_tokens) < 2:
|
||||
raise Exception(
|
||||
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
|
||||
)
|
||||
provider = model_name_tokens[0]
|
||||
# Validate azure_openai provider requires base_url
|
||||
if provider == "azure_openai" and llm_provider.get("base_url") is None:
|
||||
raise Exception(
|
||||
f"Provider 'azure_openai' requires 'base_url' to be set for model {model_name}"
|
||||
)
|
||||
|
||||
model_id = "/".join(model_name_tokens[1:])
|
||||
if provider not in SUPPORTED_PROVIDERS:
|
||||
if (
|
||||
|
|
@ -151,16 +160,6 @@ def validate_and_render_schema():
|
|||
|
||||
llm_provider["model"] = model_id
|
||||
llm_provider["provider_interface"] = provider
|
||||
llm_provider_name_set.add(llm_provider.get("name"))
|
||||
provider = None
|
||||
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
|
||||
raise Exception(
|
||||
"Please provide either provider or provider_interface, not both"
|
||||
)
|
||||
if llm_provider.get("provider"):
|
||||
provider = llm_provider["provider"]
|
||||
llm_provider["provider_interface"] = provider
|
||||
del llm_provider["provider"]
|
||||
updated_llm_providers.append(llm_provider)
|
||||
|
||||
if llm_provider.get("base_url", None):
|
||||
|
|
@ -189,6 +188,9 @@ def validate_and_render_schema():
|
|||
llm_provider["endpoint"] = endpoint
|
||||
llm_provider["port"] = port
|
||||
llm_provider["protocol"] = protocol
|
||||
llm_provider["cluster_name"] = (
|
||||
provider + "_" + endpoint
|
||||
) # make name unique by appending endpoint
|
||||
llms_with_endpoint.append(llm_provider)
|
||||
|
||||
if len(model_usage_name_keys) > 0:
|
||||
|
|
|
|||
|
|
@ -167,6 +167,12 @@ pub enum LlmProviderType {
|
|||
OpenAI,
|
||||
#[serde(rename = "gemini")]
|
||||
Gemini,
|
||||
#[serde(rename = "xai")]
|
||||
XAI,
|
||||
#[serde(rename = "together_ai")]
|
||||
TogetherAI,
|
||||
#[serde(rename = "azure_openai")]
|
||||
AzureOpenAI,
|
||||
}
|
||||
|
||||
impl Display for LlmProviderType {
|
||||
|
|
@ -179,6 +185,9 @@ impl Display for LlmProviderType {
|
|||
LlmProviderType::Gemini => write!(f, "gemini"),
|
||||
LlmProviderType::Mistral => write!(f, "mistral"),
|
||||
LlmProviderType::OpenAI => write!(f, "openai"),
|
||||
LlmProviderType::XAI => write!(f, "xai"),
|
||||
LlmProviderType::TogetherAI => write!(f, "together_ai"),
|
||||
LlmProviderType::AzureOpenAI => write!(f, "azure_openai"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -217,6 +226,7 @@ pub struct LlmProvider {
|
|||
pub rate_limits: Option<LlmRatelimit>,
|
||||
pub usage: Option<String>,
|
||||
pub routing_preferences: Option<Vec<RoutingPreference>>,
|
||||
pub cluster_name: Option<String>,
|
||||
}
|
||||
|
||||
pub trait IntoModels {
|
||||
|
|
@ -256,6 +266,7 @@ impl Default for LlmProvider {
|
|||
rate_limits: None,
|
||||
usage: None,
|
||||
routing_preferences: None,
|
||||
cluster_name: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ impl SupportedAPIs {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str) -> String {
|
||||
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str) -> String {
|
||||
let default_endpoint = "/v1/chat/completions".to_string();
|
||||
match self {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
|
||||
|
|
@ -80,6 +80,13 @@ impl SupportedAPIs {
|
|||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/v1beta/openai/chat/completions".to_string()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ pub enum ProviderId {
|
|||
Anthropic,
|
||||
GitHub,
|
||||
Arch,
|
||||
AzureOpenAI,
|
||||
XAI,
|
||||
TogetherAI,
|
||||
}
|
||||
|
||||
impl From<&str> for ProviderId {
|
||||
|
|
@ -26,6 +29,9 @@ impl From<&str> for ProviderId {
|
|||
"anthropic" => ProviderId::Anthropic,
|
||||
"github" => ProviderId::GitHub,
|
||||
"arch" => ProviderId::Arch,
|
||||
"azure_openai" => ProviderId::AzureOpenAI,
|
||||
"xai" => ProviderId::XAI,
|
||||
"together_ai" => ProviderId::TogetherAI,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
|
|
@ -40,8 +46,29 @@ impl ProviderId {
|
|||
(ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
(ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI,
|
||||
SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -57,6 +84,9 @@ impl Display for ProviderId {
|
|||
ProviderId::Anthropic => write!(f, "Anthropic"),
|
||||
ProviderId::GitHub => write!(f, "GitHub"),
|
||||
ProviderId::Arch => write!(f, "Arch"),
|
||||
ProviderId::AzureOpenAI => write!(f, "azure_openai"),
|
||||
ProviderId::XAI => write!(f, "xai"),
|
||||
ProviderId::TogetherAI => write!(f, "together_ai"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,8 +98,14 @@ impl StreamContext {
|
|||
fn update_upstream_path(&mut self, request_path: &str) {
|
||||
let hermes_provider_id = self.llm_provider().to_provider_id();
|
||||
if let Some(api) = &self.client_api {
|
||||
let target_endpoint =
|
||||
api.target_endpoint_for_provider(&hermes_provider_id, request_path);
|
||||
let target_endpoint = api.target_endpoint_for_provider(
|
||||
&hermes_provider_id,
|
||||
request_path,
|
||||
self.llm_provider()
|
||||
.model
|
||||
.as_ref()
|
||||
.unwrap_or(&"".to_string()),
|
||||
);
|
||||
if target_endpoint != request_path {
|
||||
self.set_http_request_header(":path", Some(&target_endpoint));
|
||||
}
|
||||
|
|
@ -622,7 +628,12 @@ impl HttpContext for StreamContext {
|
|||
if self.llm_provider().endpoint.is_some() {
|
||||
self.add_http_request_header(
|
||||
ARCH_ROUTING_HEADER,
|
||||
&self.llm_provider().name.to_string(),
|
||||
&self
|
||||
.llm_provider()
|
||||
.cluster_name
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
);
|
||||
} else {
|
||||
self.add_http_request_header(
|
||||
|
|
|
|||
|
|
@ -37,6 +37,12 @@ llm_providers:
|
|||
- access_key: $GEMINI_API_KEY
|
||||
model: gemini/gemini-1.5-pro-latest
|
||||
|
||||
- model: xai/grok-4-latest
|
||||
access_key: $GROK_API_KEY
|
||||
|
||||
- model: together_ai/openai/gpt-oss-20b
|
||||
access_key: $TOGETHER_API_KEY
|
||||
|
||||
- model: custom/test-model
|
||||
base_url: http://host.docker.internal:11223
|
||||
provider_interface: openai
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ listeners:
|
|||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
# OpenAI Models
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
|
|
@ -25,6 +26,12 @@ llm_providers:
|
|||
|
||||
- model: anthropic/claude-3-haiku-20240307
|
||||
access_key: $ANTHROPIC_API_KEY
|
||||
# Azure OpenAI Models
|
||||
|
||||
- model: azure_openai/gpt-5-mini
|
||||
access_key: $AZURE_API_KEY
|
||||
base_url: https://katanemo.openai.azure.com
|
||||
|
||||
|
||||
# Model aliases - friendly names that map to actual provider names
|
||||
model_aliases:
|
||||
|
|
|
|||
|
|
@ -41,6 +41,16 @@ llm_providers:
|
|||
- model: mistral/mistral-7b-instruct
|
||||
base_url: http://mistral_local
|
||||
|
||||
# Model aliases - friendly names that map to actual provider names
|
||||
model_aliases:
|
||||
# Alias for summarization tasks -> fast/cheap model
|
||||
arch.summarize.v1:
|
||||
target: gpt-4o
|
||||
|
||||
# Alias for general purpose tasks -> latest model
|
||||
arch.v1:
|
||||
target: mistral-8x7b
|
||||
|
||||
# provides a way to override default settings for the arch system
|
||||
overrides:
|
||||
# By default Arch uses an NLI + embedding approach to match an incoming prompt to a prompt target.
|
||||
|
|
|
|||
|
|
@ -31,12 +31,18 @@ llm_providers:
|
|||
name: mistral/mistral-8x7b
|
||||
provider_interface: mistral
|
||||
- base_url: http://mistral_local
|
||||
cluster_name: mistral_mistral_local
|
||||
endpoint: mistral_local
|
||||
model: mistral-7b-instruct
|
||||
name: mistral/mistral-7b-instruct
|
||||
port: 80
|
||||
protocol: http
|
||||
provider_interface: mistral
|
||||
model_aliases:
|
||||
arch.summarize.v1:
|
||||
target: gpt-4o
|
||||
arch.v1:
|
||||
target: mistral-8x7b
|
||||
overrides:
|
||||
prompt_target_intent_matching_threshold: 0.6
|
||||
prompt_guards:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue