mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
draft commit to add support for xAI, LambdaAI, TogehterAI, AzureOpenAI
This commit is contained in:
parent
b56311f458
commit
79ff4bb164
7 changed files with 170 additions and 24 deletions
|
|
@ -167,6 +167,14 @@ pub enum LlmProviderType {
|
|||
OpenAI,
|
||||
#[serde(rename = "gemini")]
|
||||
Gemini,
|
||||
#[serde(rename = "xai")]
|
||||
XAI,
|
||||
#[serde(rename = "together_ai")]
|
||||
TogetherAI,
|
||||
#[serde(rename = "lambda_ai")]
|
||||
LambdaAI,
|
||||
#[serde(rename = "azure_openai")]
|
||||
AzureOpenAI,
|
||||
}
|
||||
|
||||
impl Display for LlmProviderType {
|
||||
|
|
@ -179,6 +187,10 @@ impl Display for LlmProviderType {
|
|||
LlmProviderType::Gemini => write!(f, "gemini"),
|
||||
LlmProviderType::Mistral => write!(f, "mistral"),
|
||||
LlmProviderType::OpenAI => write!(f, "openai"),
|
||||
LlmProviderType::XAI => write!(f, "xai"),
|
||||
LlmProviderType::TogetherAI => write!(f, "together_ai"),
|
||||
LlmProviderType::LambdaAI => write!(f, "lambda_ai"),
|
||||
LlmProviderType::AzureOpenAI => write!(f, "azure_openai"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -217,6 +229,7 @@ pub struct LlmProvider {
|
|||
pub rate_limits: Option<LlmRatelimit>,
|
||||
pub usage: Option<String>,
|
||||
pub routing_preferences: Option<Vec<RoutingPreference>>,
|
||||
pub cluster_name: Option<String>,
|
||||
}
|
||||
|
||||
pub trait IntoModels {
|
||||
|
|
@ -256,6 +269,7 @@ impl Default for LlmProvider {
|
|||
rate_limits: None,
|
||||
usage: None,
|
||||
routing_preferences: None,
|
||||
cluster_name: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ impl SupportedAPIs {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str) -> String {
|
||||
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str) -> String {
|
||||
let default_endpoint = "/v1/chat/completions".to_string();
|
||||
match self {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
|
||||
|
|
@ -80,6 +80,13 @@ impl SupportedAPIs {
|
|||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/v1beta/openai/chat/completions".to_string()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,10 @@ pub enum ProviderId {
|
|||
Anthropic,
|
||||
GitHub,
|
||||
Arch,
|
||||
AzureOpenAI,
|
||||
XAI,
|
||||
TogetherAI,
|
||||
LambdaAI,
|
||||
}
|
||||
|
||||
impl From<&str> for ProviderId {
|
||||
|
|
@ -26,6 +30,10 @@ impl From<&str> for ProviderId {
|
|||
"anthropic" => ProviderId::Anthropic,
|
||||
"github" => ProviderId::GitHub,
|
||||
"arch" => ProviderId::Arch,
|
||||
"azure_openai" => ProviderId::AzureOpenAI,
|
||||
"xai" => ProviderId::XAI,
|
||||
"together_ai" => ProviderId::TogetherAI,
|
||||
"lambda_ai" => ProviderId::LambdaAI,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
|
|
@ -40,8 +48,31 @@ impl ProviderId {
|
|||
(ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::LambdaAI,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
(ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::LambdaAI,
|
||||
SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -57,6 +88,10 @@ impl Display for ProviderId {
|
|||
ProviderId::Anthropic => write!(f, "Anthropic"),
|
||||
ProviderId::GitHub => write!(f, "GitHub"),
|
||||
ProviderId::Arch => write!(f, "Arch"),
|
||||
ProviderId::AzureOpenAI => write!(f, "azure_openai"),
|
||||
ProviderId::XAI => write!(f, "xai"),
|
||||
ProviderId::TogetherAI => write!(f, "together_ai"),
|
||||
ProviderId::LambdaAI => write!(f, "lambda_ai"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,8 +98,14 @@ impl StreamContext {
|
|||
fn update_upstream_path(&mut self, request_path: &str) {
|
||||
let hermes_provider_id = self.llm_provider().to_provider_id();
|
||||
if let Some(api) = &self.client_api {
|
||||
let target_endpoint =
|
||||
api.target_endpoint_for_provider(&hermes_provider_id, request_path);
|
||||
let target_endpoint = api.target_endpoint_for_provider(
|
||||
&hermes_provider_id,
|
||||
request_path,
|
||||
self.llm_provider()
|
||||
.model
|
||||
.as_ref()
|
||||
.unwrap_or(&"".to_string()),
|
||||
);
|
||||
if target_endpoint != request_path {
|
||||
self.set_http_request_header(":path", Some(&target_endpoint));
|
||||
}
|
||||
|
|
@ -622,7 +628,12 @@ impl HttpContext for StreamContext {
|
|||
if self.llm_provider().endpoint.is_some() {
|
||||
self.add_http_request_header(
|
||||
ARCH_ROUTING_HEADER,
|
||||
&self.llm_provider().name.to_string(),
|
||||
&self
|
||||
.llm_provider()
|
||||
.cluster_name
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
);
|
||||
} else {
|
||||
self.add_http_request_header(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue