mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
feat: add provider arbitrage policy and fallback routing
This commit is contained in:
parent
de2d8847f3
commit
07ad4c6ae2
10 changed files with 670 additions and 57 deletions
|
|
@ -191,6 +191,7 @@ def validate_and_render_schema():
|
|||
llms_with_usage = []
|
||||
model_name_keys = set()
|
||||
model_usage_name_keys = set()
|
||||
arbitrage_rank_validations = []
|
||||
|
||||
print("listeners: ", listeners)
|
||||
|
||||
|
|
@ -254,6 +255,30 @@ def validate_and_render_schema():
|
|||
f"Model {model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards."
|
||||
)
|
||||
|
||||
arbitrage_policy = model_provider.get("arbitrage_policy")
|
||||
if arbitrage_policy:
|
||||
arbitrage_enabled = arbitrage_policy.get("enabled", False)
|
||||
arbitrage_rank = arbitrage_policy.get("rank", [])
|
||||
|
||||
if arbitrage_enabled and len(arbitrage_rank) == 0:
|
||||
raise Exception(
|
||||
f"Model {model_name} has arbitrage_policy.enabled=true but rank is empty. Please provide at least one ranked candidate."
|
||||
)
|
||||
|
||||
if arbitrage_enabled and is_wildcard:
|
||||
raise Exception(
|
||||
f"Model {model_name} has arbitrage_policy.enabled=true but uses wildcard (*). Arbitrage policy requires deterministic model candidates."
|
||||
)
|
||||
|
||||
if len(arbitrage_rank) != len(set(arbitrage_rank)):
|
||||
raise Exception(
|
||||
f"Model {model_name} has duplicate entries in arbitrage_policy.rank. Please provide each candidate once in ranked order."
|
||||
)
|
||||
|
||||
if arbitrage_enabled:
|
||||
provider_label = model_provider.get("name") or model_name
|
||||
arbitrage_rank_validations.append((provider_label, arbitrage_rank))
|
||||
|
||||
# Validate azure_openai and ollama provider requires base_url
|
||||
if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get(
|
||||
"base_url"
|
||||
|
|
@ -417,6 +442,16 @@ def validate_and_render_schema():
|
|||
}
|
||||
)
|
||||
|
||||
arbitrage_allowed_targets = model_name_keys.union(model_provider_name_set)
|
||||
for provider_name, rank in arbitrage_rank_validations:
|
||||
for ranked_candidate in rank:
|
||||
if ranked_candidate not in arbitrage_allowed_targets:
|
||||
raise Exception(
|
||||
f"Model provider '{provider_name}' has arbitrage_policy.rank candidate '{ranked_candidate}' "
|
||||
"that is not defined in model_providers. "
|
||||
"Use a configured provider name, model id, or provider/model slug."
|
||||
)
|
||||
|
||||
config_yaml["model_providers"] = deepcopy(updated_model_providers)
|
||||
|
||||
listeners_with_provider = 0
|
||||
|
|
|
|||
|
|
@ -289,6 +289,107 @@ llm_providers:
|
|||
tracing:
|
||||
random_sampling: 100
|
||||
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "arbitrage_policy_enabled_requires_non_empty_rank",
|
||||
"expected_error": "arbitrage_policy.enabled=true but rank is empty",
|
||||
"plano_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
arbitrage_policy:
|
||||
enabled: true
|
||||
rank: []
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "arbitrage_policy_rank_candidate_must_exist",
|
||||
"expected_error": "arbitrage_policy.rank candidate 'openai/not-configured'",
|
||||
"plano_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
arbitrage_policy:
|
||||
enabled: true
|
||||
rank:
|
||||
- openai/not-configured
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "arbitrage_policy_rejects_duplicate_rank_entries",
|
||||
"expected_error": "duplicate entries in arbitrage_policy.rank",
|
||||
"plano_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
arbitrage_policy:
|
||||
enabled: true
|
||||
rank:
|
||||
- openai/gpt-4o-mini
|
||||
- openai/gpt-4o-mini
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "arbitrage_policy_valid_rank",
|
||||
"expected_error": None,
|
||||
"plano_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
|
||||
- model: groq/llama-3.1-8b-instant
|
||||
access_key: $GROQ_API_KEY
|
||||
arbitrage_policy:
|
||||
enabled: true
|
||||
rank:
|
||||
- openai/gpt-4o-mini
|
||||
- openai/gpt-4o
|
||||
on_failure:
|
||||
fallback_to_primary: true
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
|
|
|||
|
|
@ -193,6 +193,22 @@ properties:
|
|||
required:
|
||||
- name
|
||||
- description
|
||||
arbitrage_policy:
|
||||
type: object
|
||||
properties:
|
||||
enabled:
|
||||
type: boolean
|
||||
rank:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
on_failure:
|
||||
type: object
|
||||
properties:
|
||||
fallback_to_primary:
|
||||
type: boolean
|
||||
additionalProperties: false
|
||||
additionalProperties: false
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model
|
||||
|
|
@ -240,6 +256,22 @@ properties:
|
|||
required:
|
||||
- name
|
||||
- description
|
||||
arbitrage_policy:
|
||||
type: object
|
||||
properties:
|
||||
enabled:
|
||||
type: boolean
|
||||
rank:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
on_failure:
|
||||
type: object
|
||||
properties:
|
||||
fallback_to_primary:
|
||||
type: boolean
|
||||
additionalProperties: false
|
||||
additionalProperties: false
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model
|
||||
|
|
|
|||
|
|
@ -29,11 +29,56 @@ use crate::state::{
|
|||
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
|
||||
};
|
||||
use crate::tracing::{
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component,
|
||||
routing as tracing_routing, set_service_name,
|
||||
};
|
||||
|
||||
use common::errors::BrightStaffError;
|
||||
|
||||
/// Return the bare model identifier with any leading `provider/` prefix removed.
///
/// Splits on the *first* slash only, so `"together_ai/openai/gpt-oss-20b"`
/// yields `"openai/gpt-oss-20b"`. A string without a slash is returned as-is.
fn strip_provider_prefix(model: &str) -> String {
    match model.split_once('/') {
        Some((_provider, bare_model)) => bare_model.to_string(),
        None => model.to_string(),
    }
}
|
||||
|
||||
fn is_retryable_upstream_status(status: reqwest::StatusCode) -> bool {
|
||||
matches!(
|
||||
status,
|
||||
reqwest::StatusCode::TOO_MANY_REQUESTS
|
||||
| reqwest::StatusCode::BAD_GATEWAY
|
||||
| reqwest::StatusCode::SERVICE_UNAVAILABLE
|
||||
| reqwest::StatusCode::GATEWAY_TIMEOUT
|
||||
)
|
||||
}
|
||||
|
||||
async fn build_arbitrage_candidate_chain(
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
primary_model: &str,
|
||||
) -> Vec<String> {
|
||||
let mut chain = Vec::new();
|
||||
let providers = llm_providers.read().await;
|
||||
|
||||
if let Some(provider) = providers.get(primary_model) {
|
||||
if let Some(arbitrage_policy) = &provider.arbitrage_policy {
|
||||
if arbitrage_policy.enabled.unwrap_or(false) {
|
||||
for ranked_candidate in arbitrage_policy.rank.clone().unwrap_or_default() {
|
||||
if !chain.contains(&ranked_candidate) {
|
||||
chain.push(ranked_candidate);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !chain.contains(&primary_model.to_string()) {
|
||||
chain.push(primary_model.to_string());
|
||||
}
|
||||
|
||||
chain
|
||||
}
|
||||
|
||||
pub async fn llm_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
|
|
@ -97,7 +142,7 @@ async fn llm_chat_inner(
|
|||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
mut request_headers: hyper::HeaderMap,
|
||||
request_headers: hyper::HeaderMap,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Set service name for LLM operations
|
||||
set_service_name(operation_component::LLM);
|
||||
|
|
@ -331,9 +376,6 @@ async fn llm_chat_inner(
|
|||
}
|
||||
}
|
||||
|
||||
// Serialize request for upstream BEFORE router consumes it
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// Determine routing using the dedicated router_chat module
|
||||
// This gets its own span for latency and error tracking
|
||||
let routing_span = info_span!(
|
||||
|
|
@ -374,73 +416,209 @@ async fn llm_chat_inner(
|
|||
// Determine final model to use
|
||||
// Router returns "none" as a sentinel value when it doesn't select a specific model
|
||||
let router_selected_model = routing_result.model_name;
|
||||
let resolved_model = if router_selected_model != "none" {
|
||||
let primary_model = if router_selected_model != "none" {
|
||||
// Router selected a specific model via routing preferences
|
||||
router_selected_model
|
||||
} else {
|
||||
// Router returned "none" sentinel, use validated resolved_model from request
|
||||
alias_resolved_model.clone()
|
||||
};
|
||||
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
|
||||
let arbitrage_chain = build_arbitrage_candidate_chain(&llm_providers, &primary_model).await;
|
||||
let mut selected_model = primary_model.clone();
|
||||
let mut llm_response: Option<reqwest::Response> = None;
|
||||
let mut last_transport_error: Option<String> = None;
|
||||
let request_start_time = std::time::Instant::now();
|
||||
let http_client = reqwest::Client::new();
|
||||
|
||||
let span_name = if model_from_request == resolved_model {
|
||||
format!("POST {} {}", request_path, resolved_model)
|
||||
for (attempt_idx, candidate_model) in arbitrage_chain.iter().enumerate() {
|
||||
selected_model = candidate_model.clone();
|
||||
let candidate_model_name = strip_provider_prefix(candidate_model);
|
||||
let mut candidate_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
candidate = %candidate_model,
|
||||
error = %err,
|
||||
"failed to build candidate request"
|
||||
);
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
candidate_request.set_model(candidate_model_name.clone());
|
||||
if candidate_request.remove_metadata_key("plano_preference_config") {
|
||||
debug!("removed plano_preference_config from candidate metadata");
|
||||
}
|
||||
|
||||
let (candidate_provider_id, _) = get_provider_info(&llm_providers, candidate_model).await;
|
||||
let selection_reason = if attempt_idx == 0 {
|
||||
if candidate_model == &primary_model {
|
||||
"router_selected_primary"
|
||||
} else {
|
||||
"free_tier_available"
|
||||
}
|
||||
} else {
|
||||
"fallback_on_retryable_error"
|
||||
};
|
||||
let is_fallback_attempt = attempt_idx > 0;
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_routing::SELECTION_REASON,
|
||||
selection_reason.to_string(),
|
||||
));
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_routing::IS_FALLBACK,
|
||||
is_fallback_attempt,
|
||||
));
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_llm::PROVIDER,
|
||||
candidate_provider_id.to_string(),
|
||||
));
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
"routing.attempt_index",
|
||||
(attempt_idx + 1) as i64,
|
||||
));
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
"routing.attempt_total",
|
||||
arbitrage_chain.len() as i64,
|
||||
));
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_routing::UPSTREAM_ENDPOINT,
|
||||
candidate_model.clone(),
|
||||
));
|
||||
});
|
||||
if let Some(ref client_api_kind) = client_api {
|
||||
let upstream_api = candidate_provider_id
|
||||
.compatible_api_for_client(client_api_kind, is_streaming_request);
|
||||
candidate_request.normalize_for_upstream(candidate_provider_id, &upstream_api);
|
||||
}
|
||||
let candidate_request_bytes = ProviderRequestType::to_bytes(&candidate_request).unwrap();
|
||||
|
||||
let mut candidate_headers = request_headers.clone();
|
||||
candidate_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(candidate_model).unwrap(),
|
||||
);
|
||||
candidate_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
candidate_headers.remove(header::CONTENT_LENGTH);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut candidate_headers));
|
||||
});
|
||||
|
||||
debug!(
|
||||
url = %full_qualified_llm_provider_url,
|
||||
provider_hint = %candidate_model,
|
||||
upstream_model = %candidate_model_name,
|
||||
selection_reason = %selection_reason,
|
||||
attempt_index = attempt_idx + 1,
|
||||
attempt_total = arbitrage_chain.len(),
|
||||
"Routing candidate to upstream"
|
||||
);
|
||||
|
||||
let response = match http_client
|
||||
.post(&full_qualified_llm_provider_url)
|
||||
.headers(candidate_headers)
|
||||
.body(candidate_request_bytes)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
last_transport_error = Some(err.to_string());
|
||||
if attempt_idx + 1 < arbitrage_chain.len() {
|
||||
let next_candidate = arbitrage_chain[attempt_idx + 1].as_str();
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
"routing.fallback_trigger",
|
||||
"transport_error".to_string(),
|
||||
));
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
"routing.next_candidate",
|
||||
next_candidate.to_string(),
|
||||
));
|
||||
});
|
||||
warn!(
|
||||
candidate = %candidate_model,
|
||||
error = %err,
|
||||
next_candidate = %next_candidate,
|
||||
attempt_index = attempt_idx + 1,
|
||||
attempt_total = arbitrage_chain.len(),
|
||||
"candidate transport failure, trying next fallback"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
return Ok(BrightStaffError::InternalServerError(format!(
|
||||
"Failed to send request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
if is_retryable_upstream_status(status) && attempt_idx + 1 < arbitrage_chain.len() {
|
||||
let next_candidate = arbitrage_chain[attempt_idx + 1].as_str();
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
"routing.fallback_trigger",
|
||||
format!("http_{}", status.as_u16()),
|
||||
));
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
"routing.next_candidate",
|
||||
next_candidate.to_string(),
|
||||
));
|
||||
});
|
||||
warn!(
|
||||
candidate = %candidate_model,
|
||||
status = status.as_u16(),
|
||||
next_candidate = %next_candidate,
|
||||
attempt_index = attempt_idx + 1,
|
||||
attempt_total = arbitrage_chain.len(),
|
||||
"candidate returned retryable status, trying next fallback"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
llm_response = Some(response);
|
||||
break;
|
||||
}
|
||||
|
||||
let llm_response = match llm_response {
|
||||
Some(response) => response,
|
||||
None => {
|
||||
return Ok(BrightStaffError::InternalServerError(format!(
|
||||
"Failed to send request across arbitrage chain: {}",
|
||||
last_transport_error.unwrap_or_else(|| "unknown error".to_string())
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
tracing::Span::current().record(tracing_llm::MODEL_NAME, selected_model.as_str());
|
||||
let span_name = if model_from_request == selected_model {
|
||||
format!("POST {} {}", request_path, selected_model)
|
||||
} else {
|
||||
format!(
|
||||
"POST {} {} -> {}",
|
||||
request_path, model_from_request, resolved_model
|
||||
request_path, model_from_request, selected_model
|
||||
)
|
||||
};
|
||||
get_active_span(|span| {
|
||||
span.update_name(span_name.clone());
|
||||
});
|
||||
|
||||
debug!(
|
||||
url = %full_qualified_llm_provider_url,
|
||||
provider_hint = %resolved_model,
|
||||
upstream_model = %model_name_only,
|
||||
"Routing to upstream"
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&resolved_model).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// Inject current LLM span's trace context so upstream spans are children of plano(llm)
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
||||
});
|
||||
|
||||
// Capture start time right before sending request to upstream
|
||||
let request_start_time = std::time::Instant::now();
|
||||
let _request_start_system_time = std::time::SystemTime::now();
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(&full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
return Ok(BrightStaffError::InternalServerError(format!(
|
||||
"Failed to send request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// copy over the headers and status code from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
|
|
@ -480,7 +658,7 @@ async fn llm_chat_inner(
|
|||
state_store,
|
||||
original_input_items,
|
||||
alias_resolved_model.clone(),
|
||||
resolved_model.clone(),
|
||||
selected_model.clone(),
|
||||
is_streaming_request,
|
||||
false, // Not OpenAI upstream since should_manage_state is true
|
||||
content_encoding,
|
||||
|
|
@ -570,3 +748,101 @@ async fn get_provider_info(
|
|||
(hermesllm::ProviderId::OpenAI, None)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::{
        build_arbitrage_candidate_chain, is_retryable_upstream_status, strip_provider_prefix,
    };
    use common::configuration::{
        ArbitrageFailurePolicy, ArbitragePolicy, LlmProvider, LlmProviderType,
    };
    use common::llm_providers::LlmProviders;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    // Minimal LlmProvider fixture: only name, model, and default are
    // meaningful for these tests; every optional field is left unset.
    fn provider(name: &str, model: &str, default: bool) -> LlmProvider {
        LlmProvider {
            name: name.to_string(),
            model: Some(model.to_string()),
            provider_interface: LlmProviderType::OpenAI,
            default: Some(default),
            stream: None,
            access_key: None,
            endpoint: None,
            port: None,
            rate_limits: None,
            usage: None,
            routing_preferences: None,
            cluster_name: None,
            base_url_path_prefix: None,
            internal: None,
            passthrough_auth: None,
            arbitrage_policy: None,
        }
    }

    // The prefix is dropped only up to the first '/'; a bare model name
    // passes through unchanged.
    #[test]
    fn strips_provider_prefix() {
        assert_eq!(strip_provider_prefix("openai/gpt-4o-mini"), "gpt-4o-mini");
        assert_eq!(strip_provider_prefix("gpt-4o-mini"), "gpt-4o-mini");
    }

    // Exhaustively checks the retryable set (429/502/503/504) and two
    // representative non-retryable client errors (400/401).
    #[test]
    fn retryable_status_matrix_is_deterministic() {
        assert!(is_retryable_upstream_status(
            reqwest::StatusCode::TOO_MANY_REQUESTS
        ));
        assert!(is_retryable_upstream_status(
            reqwest::StatusCode::BAD_GATEWAY
        ));
        assert!(is_retryable_upstream_status(
            reqwest::StatusCode::SERVICE_UNAVAILABLE
        ));
        assert!(is_retryable_upstream_status(
            reqwest::StatusCode::GATEWAY_TIMEOUT
        ));
        assert!(!is_retryable_upstream_status(
            reqwest::StatusCode::BAD_REQUEST
        ));
        assert!(!is_retryable_upstream_status(
            reqwest::StatusCode::UNAUTHORIZED
        ));
    }

    // The chain must list ranked candidates in policy order, with the
    // primary provider appended last as the final fallback.
    #[tokio::test]
    async fn arbitrage_chain_is_ranked_then_primary() {
        let mut primary = provider("openai/gpt-4o-mini", "gpt-4o-mini", true);
        primary.arbitrage_policy = Some(ArbitragePolicy {
            enabled: Some(true),
            rank: Some(vec![
                "groq/llama-3.1-8b-instant".to_string(),
                "together_ai/openai/gpt-oss-20b".to_string(),
            ]),
            on_failure: Some(ArbitrageFailurePolicy {
                fallback_to_primary: Some(true),
            }),
        });

        let providers = vec![
            primary,
            provider("groq/llama-3.1-8b-instant", "llama-3.1-8b-instant", false),
            provider(
                "together_ai/openai/gpt-oss-20b",
                "openai/gpt-oss-20b",
                false,
            ),
        ];
        let llm_providers = Arc::new(RwLock::new(LlmProviders::try_from(providers).unwrap()));

        let chain = build_arbitrage_candidate_chain(&llm_providers, "openai/gpt-4o-mini").await;
        assert_eq!(
            chain,
            vec![
                "groq/llama-3.1-8b-instant".to_string(),
                "together_ai/openai/gpt-oss-20b".to_string(),
                "openai/gpt-4o-mini".to_string(),
            ]
        );
    }
}
|
||||
|
|
|
|||
|
|
@ -277,6 +277,18 @@ pub struct RoutingPreference {
|
|||
pub description: String,
|
||||
}
|
||||
|
||||
/// Failure-handling knobs for an arbitrage policy.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ArbitrageFailurePolicy {
    // NOTE(review): not read by the visible routing code — presumably gates
    // whether the primary provider is retried after all ranked candidates
    // fail; confirm against the consumer before documenting as a contract.
    pub fallback_to_primary: Option<bool>,
}
|
||||
|
||||
/// Provider-level arbitrage configuration (`arbitrage_policy` in config).
///
/// All fields are optional so the section can be omitted entirely; a missing
/// or `false` `enabled` disables arbitrage for the provider.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ArbitragePolicy {
    // Treated as false when absent (`enabled.unwrap_or(false)` at the call site).
    pub enabled: Option<bool>,
    // Ordered candidate provider/model identifiers tried before the primary.
    pub rank: Option<Vec<String>>,
    // Behavior when every ranked candidate fails.
    pub on_failure: Option<ArbitrageFailurePolicy>,
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct AgentUsagePreference {
|
||||
pub model: String,
|
||||
|
|
@ -331,6 +343,7 @@ pub struct LlmProvider {
|
|||
pub base_url_path_prefix: Option<String>,
|
||||
pub internal: Option<bool>,
|
||||
pub passthrough_auth: Option<bool>,
|
||||
pub arbitrage_policy: Option<ArbitragePolicy>,
|
||||
}
|
||||
|
||||
pub trait IntoModels {
|
||||
|
|
@ -375,6 +388,7 @@ impl Default for LlmProvider {
|
|||
base_url_path_prefix: None,
|
||||
internal: None,
|
||||
passthrough_auth: None,
|
||||
arbitrage_policy: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -278,6 +278,7 @@ mod tests {
|
|||
internal: None,
|
||||
stream: None,
|
||||
passthrough_auth: None,
|
||||
arbitrage_policy: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
38
demos/llm_routing/gpu_free_tier_arbitrage/README.md
Normal file
38
demos/llm_routing/gpu_free_tier_arbitrage/README.md
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
# GPU Free-Tier Arbitrage Demo
|
||||
|
||||
This demo package showcases provider-level free-tier-first routing and deterministic fallback using a local Plano endpoint on `localhost:12000`.
|
||||
|
||||
## Files
|
||||
|
||||
- `config.yaml` - demo Plano config with `arbitrage_policy`
|
||||
- `demo.rest` - runnable REST requests for IDE REST clients
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Set API keys for providers used in this demo:
|
||||
|
||||
- `OPENAI_API_KEY`
|
||||
- `GROQ_API_KEY`
|
||||
- `TOGETHER_API_KEY`
|
||||
|
||||
## Run the demo
|
||||
|
||||
From this directory:
|
||||
|
||||
```bash
|
||||
planoai up config.yaml
|
||||
```
|
||||
|
||||
Then run requests from `demo.rest` in your REST client.
|
||||
|
||||
## What to show during the demo
|
||||
|
||||
1. Run `free-tier-first showcase` and verify response success.
|
||||
2. Inspect logs/traces for provider selection reason and selected candidate.
|
||||
3. Force a retryable error on the first candidate (for example, temporarily invalid key), then run `fallback showcase`.
|
||||
4. Verify fallback metadata appears in traces/logs:
|
||||
- `routing.selection_reason`
|
||||
- `routing.is_fallback`
|
||||
- `routing.fallback_trigger`
|
||||
- `routing.next_candidate`
|
||||
- `routing.upstream_endpoint`
|
||||
30
demos/llm_routing/gpu_free_tier_arbitrage/config.yaml
Normal file
30
demos/llm_routing/gpu_free_tier_arbitrage/config.yaml
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
version: v0.3.0
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
name: model_listener
|
||||
port: 12000
|
||||
max_retries: 1
|
||||
|
||||
model_providers:
|
||||
# Primary provider for the model.
|
||||
- model: openai/gpt-5.2
|
||||
# This is a failure key to test the arbitrage policy
|
||||
access_key: $OPENAI_API_KEY_FAILURE
|
||||
default: true
|
||||
arbitrage_policy:
|
||||
enabled: true
|
||||
rank:
|
||||
# Demo low-cost/free-tier candidates (ordered).
|
||||
- ollama/qwen3:8b
|
||||
- groq/llama-3.1-8b-instant
|
||||
|
||||
# Candidates referenced by arbitrage_policy.rank.
|
||||
- model: groq/llama-3.1-8b-instant
|
||||
access_key: $GROQ_API_KEY
|
||||
|
||||
- model: ollama/qwen3:8b
|
||||
base_url: http://localhost:11434
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
31
demos/llm_routing/gpu_free_tier_arbitrage/demo.rest
Normal file
31
demos/llm_routing/gpu_free_tier_arbitrage/demo.rest
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
@llm_endpoint = http://localhost:12000
|
||||
|
||||
### free-tier-first showcase
|
||||
POST {{llm_endpoint}}/v1/chat/completions HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-5.2",
|
||||
"stream": false,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Reply with exactly: free-tier-first routing demo successful."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
### fallback showcase (run after forcing first candidate failure)
|
||||
POST {{llm_endpoint}}/v1/chat/completions HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-5.2",
|
||||
"stream": false,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Reply with exactly: fallback routing demo successful."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -430,6 +430,61 @@ Here are common scenarios where Arch-Router excels:
|
|||
|
||||
- **Conversational Routing**: Track conversation context to identify when topics shift between domains or when the type of assistance needed changes mid-conversation.
|
||||
|
||||
GPU Free-Tier Arbitrage
|
||||
-----------------------
|
||||
|
||||
Plano can apply a provider-level arbitrage policy so low-stakes or bursty traffic tries free/low-cost providers first, then deterministically falls back to the primary provider when retryable failures occur.
|
||||
|
||||
Arbitrage policy config
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Define ``arbitrage_policy`` on the primary provider:
|
||||
|
||||
.. code-block:: yaml
|
||||
:caption: Free-tier-first Arbitrage Configuration
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
arbitrage_policy:
|
||||
enabled: true
|
||||
rank:
|
||||
- groq/llama-3.1-8b-instant
|
||||
- together_ai/openai/gpt-oss-20b
|
||||
on_failure:
|
||||
fallback_to_primary: true
|
||||
|
||||
- model: groq/llama-3.1-8b-instant
|
||||
access_key: $GROQ_API_KEY
|
||||
|
||||
- model: together_ai/openai/gpt-oss-20b
|
||||
access_key: $TOGETHER_API_KEY
|
||||
|
||||
Deterministic fallback behavior
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- Candidate chain is evaluated in order: ``rank`` entries, then the primary provider
|
||||
- Retryable failures trigger fallback: transport errors, HTTP ``429``, ``502``, ``503``, ``504``
|
||||
- Non-retryable failures stop the chain immediately
|
||||
- If all candidates fail, Plano returns an explicit error (no silent degradation)
|
||||
|
||||
Trace visibility for each decision
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Every attempt emits structured decision metadata. At minimum, inspect:
|
||||
|
||||
- ``llm.model`` and ``llm.provider`` for the selected upstream at each hop
|
||||
- ``routing.selection_reason`` (for example ``free_tier_available`` or ``fallback_on_retryable_error``)
|
||||
- ``routing.is_fallback`` to identify fallback attempts
|
||||
- ``routing.fallback_trigger`` and ``routing.next_candidate`` when a retryable failure causes fallback
|
||||
- ``routing.upstream_endpoint`` for the selected candidate in each attempt
|
||||
|
||||
You can run a local showcase with:
|
||||
|
||||
- ``demos/llm_routing/gpu_free_tier_arbitrage/config.yaml``
|
||||
- ``demos/llm_routing/gpu_free_tier_arbitrage/demo.rest``
|
||||
|
||||
Best practices
|
||||
--------------
|
||||
- **💡Consistent Naming:** Route names should align with their descriptions.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue