mirror of
https://github.com/katanemo/plano.git
synced 2026-05-11 16:52:41 +02:00
Model affinity for consistent model selection in agentic loops (#827)
Some checks are pending
CI / pre-commit (push) Waiting to run
CI / plano-tools-tests (push) Waiting to run
CI / native-smoke-test (push) Waiting to run
CI / docker-build (push) Waiting to run
CI / validate-config (push) Waiting to run
CI / security-scan (push) Blocked by required conditions
CI / test-prompt-gateway (push) Blocked by required conditions
CI / test-model-alias-routing (push) Blocked by required conditions
CI / test-responses-api-with-state (push) Blocked by required conditions
CI / e2e-plano-tests (3.10) (push) Blocked by required conditions
CI / e2e-plano-tests (3.11) (push) Blocked by required conditions
CI / e2e-plano-tests (3.12) (push) Blocked by required conditions
CI / e2e-plano-tests (3.13) (push) Blocked by required conditions
CI / e2e-plano-tests (3.14) (push) Blocked by required conditions
CI / e2e-demo-preference (push) Blocked by required conditions
CI / e2e-demo-currency (push) Blocked by required conditions
Publish docker image (latest) / build-arm64 (push) Waiting to run
Publish docker image (latest) / build-amd64 (push) Waiting to run
Publish docker image (latest) / create-manifest (push) Blocked by required conditions
Build and Deploy Documentation / build (push) Waiting to run
Some checks are pending
CI / pre-commit (push) Waiting to run
CI / plano-tools-tests (push) Waiting to run
CI / native-smoke-test (push) Waiting to run
CI / docker-build (push) Waiting to run
CI / validate-config (push) Waiting to run
CI / security-scan (push) Blocked by required conditions
CI / test-prompt-gateway (push) Blocked by required conditions
CI / test-model-alias-routing (push) Blocked by required conditions
CI / test-responses-api-with-state (push) Blocked by required conditions
CI / e2e-plano-tests (3.10) (push) Blocked by required conditions
CI / e2e-plano-tests (3.11) (push) Blocked by required conditions
CI / e2e-plano-tests (3.12) (push) Blocked by required conditions
CI / e2e-plano-tests (3.13) (push) Blocked by required conditions
CI / e2e-plano-tests (3.14) (push) Blocked by required conditions
CI / e2e-demo-preference (push) Blocked by required conditions
CI / e2e-demo-currency (push) Blocked by required conditions
Publish docker image (latest) / build-arm64 (push) Waiting to run
Publish docker image (latest) / build-amd64 (push) Waiting to run
Publish docker image (latest) / create-manifest (push) Blocked by required conditions
Build and Deploy Documentation / build (push) Waiting to run
This commit is contained in:
parent
978b1ea722
commit
8dedf0bec1
13 changed files with 614 additions and 43 deletions
|
|
@ -1,6 +1,6 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{FilterPipeline, ModelAlias};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, MODEL_AFFINITY_HEADER};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hermesllm::apis::openai_responses::InputParam;
|
||||
|
|
@ -94,6 +94,21 @@ async fn llm_chat_inner(
|
|||
|
||||
let traceparent = extract_or_generate_traceparent(&request_headers);
|
||||
|
||||
// Session pinning: extract session ID and check cache before routing
|
||||
let session_id: Option<String> = request_headers
|
||||
.get(MODEL_AFFINITY_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let pinned_model: Option<String> = if let Some(ref sid) = session_id {
|
||||
state
|
||||
.router_service
|
||||
.get_cached_route(sid)
|
||||
.await
|
||||
.map(|c| c.model_name)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path);
|
||||
|
||||
// --- Phase 1: Parse and validate the incoming request ---
|
||||
|
|
@ -244,46 +259,65 @@ async fn llm_chat_inner(
|
|||
}
|
||||
};
|
||||
|
||||
// --- Phase 3: Route the request ---
|
||||
let routing_span = info_span!(
|
||||
"routing",
|
||||
component = "routing",
|
||||
http.method = "POST",
|
||||
http.target = %request_path,
|
||||
model.requested = %model_from_request,
|
||||
model.alias_resolved = %alias_resolved_model,
|
||||
route.selected_model = tracing::field::Empty,
|
||||
routing.determination_ms = tracing::field::Empty,
|
||||
);
|
||||
let routing_result = match async {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
router_chat_get_upstream_model(
|
||||
Arc::clone(&state.router_service),
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_preferences,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.instrument(routing_span)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
let mut internal_error = Response::new(full(err.message));
|
||||
*internal_error.status_mut() = err.status_code;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
// Determine final model (router returns "none" when it doesn't select a specific model)
|
||||
let router_selected_model = routing_result.model_name;
|
||||
let resolved_model = if router_selected_model != "none" {
|
||||
router_selected_model
|
||||
// --- Phase 3: Route the request (or use pinned model from session cache) ---
|
||||
let resolved_model = if let Some(cached_model) = pinned_model {
|
||||
info!(
|
||||
session_id = %session_id.as_deref().unwrap_or(""),
|
||||
model = %cached_model,
|
||||
"using pinned routing decision from cache"
|
||||
);
|
||||
cached_model
|
||||
} else {
|
||||
alias_resolved_model.clone()
|
||||
let routing_span = info_span!(
|
||||
"routing",
|
||||
component = "routing",
|
||||
http.method = "POST",
|
||||
http.target = %request_path,
|
||||
model.requested = %model_from_request,
|
||||
model.alias_resolved = %alias_resolved_model,
|
||||
route.selected_model = tracing::field::Empty,
|
||||
routing.determination_ms = tracing::field::Empty,
|
||||
);
|
||||
let routing_result = match async {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
router_chat_get_upstream_model(
|
||||
Arc::clone(&state.router_service),
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_preferences,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.instrument(routing_span)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
let mut internal_error = Response::new(full(err.message));
|
||||
*internal_error.status_mut() = err.status_code;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let (router_selected_model, route_name) =
|
||||
(routing_result.model_name, routing_result.route_name);
|
||||
let model = if router_selected_model != "none" {
|
||||
router_selected_model
|
||||
} else {
|
||||
alias_resolved_model.clone()
|
||||
};
|
||||
|
||||
// Cache the routing decision so subsequent requests with the same session ID are pinned
|
||||
if let Some(ref sid) = session_id {
|
||||
state
|
||||
.router_service
|
||||
.cache_route(sid.clone(), model.clone(), route_name)
|
||||
.await;
|
||||
}
|
||||
|
||||
model
|
||||
};
|
||||
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{SpanAttributes, TopLevelRoutingPreference};
|
||||
use common::consts::REQUEST_ID_HEADER;
|
||||
use common::consts::{MODEL_AFFINITY_HEADER, REQUEST_ID_HEADER};
|
||||
use common::errors::BrightStaffError;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::ProviderRequestType;
|
||||
|
|
@ -53,6 +53,9 @@ struct RoutingDecisionResponse {
|
|||
models: Vec<String>,
|
||||
route: Option<String>,
|
||||
trace_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
pinned: bool,
|
||||
}
|
||||
|
||||
pub async fn routing_decision(
|
||||
|
|
@ -68,6 +71,11 @@ pub async fn routing_decision(
|
|||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
||||
|
||||
let session_id: Option<String> = request_headers
|
||||
.get(MODEL_AFFINITY_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref());
|
||||
|
||||
let request_span = info_span!(
|
||||
|
|
@ -85,6 +93,7 @@ pub async fn routing_decision(
|
|||
request_path,
|
||||
request_headers,
|
||||
custom_attrs,
|
||||
session_id,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
|
|
@ -97,6 +106,7 @@ async fn routing_decision_inner(
|
|||
request_path: String,
|
||||
request_headers: hyper::HeaderMap,
|
||||
custom_attrs: std::collections::HashMap<String, String>,
|
||||
session_id: Option<String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
opentelemetry::trace::get_active_span(|span| {
|
||||
|
|
@ -114,6 +124,34 @@ async fn routing_decision_inner(
|
|||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
// Session pinning: check cache before doing any routing work
|
||||
if let Some(ref sid) = session_id {
|
||||
if let Some(cached) = router_service.get_cached_route(sid).await {
|
||||
info!(
|
||||
session_id = %sid,
|
||||
model = %cached.model_name,
|
||||
route = ?cached.route_name,
|
||||
"returning pinned routing decision from cache"
|
||||
);
|
||||
let response = RoutingDecisionResponse {
|
||||
models: vec![cached.model_name],
|
||||
route: cached.route_name,
|
||||
trace_id,
|
||||
session_id: Some(sid.clone()),
|
||||
pinned: true,
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let body = Full::new(Bytes::from(json))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
let raw_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
|
|
@ -152,7 +190,7 @@ async fn routing_decision_inner(
|
|||
};
|
||||
|
||||
let routing_result = router_chat_get_upstream_model(
|
||||
router_service,
|
||||
Arc::clone(&router_service),
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
|
|
@ -163,10 +201,23 @@ async fn routing_decision_inner(
|
|||
|
||||
match routing_result {
|
||||
Ok(result) => {
|
||||
// Cache the result if session_id is present
|
||||
if let Some(ref sid) = session_id {
|
||||
router_service
|
||||
.cache_route(
|
||||
sid.clone(),
|
||||
result.model_name.clone(),
|
||||
result.route_name.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let response = RoutingDecisionResponse {
|
||||
models: result.models,
|
||||
route: result.route_name,
|
||||
trace_id,
|
||||
session_id,
|
||||
pinned: false,
|
||||
};
|
||||
|
||||
info!(
|
||||
|
|
@ -329,6 +380,8 @@ mod tests {
|
|||
],
|
||||
route: Some("code_generation".to_string()),
|
||||
trace_id: "abc123".to_string(),
|
||||
session_id: Some("sess-abc".to_string()),
|
||||
pinned: true,
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
|
|
@ -336,6 +389,8 @@ mod tests {
|
|||
assert_eq!(parsed["models"][1], "openai/gpt-4o");
|
||||
assert_eq!(parsed["route"], "code_generation");
|
||||
assert_eq!(parsed["trace_id"], "abc123");
|
||||
assert_eq!(parsed["session_id"], "sess-abc");
|
||||
assert_eq!(parsed["pinned"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -344,10 +399,14 @@ mod tests {
|
|||
models: vec!["none".to_string()],
|
||||
route: None,
|
||||
trace_id: "abc123".to_string(),
|
||||
session_id: None,
|
||||
pinned: false,
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed["models"][0], "none");
|
||||
assert!(parsed["route"].is_null());
|
||||
assert!(parsed.get("session_id").is_none());
|
||||
assert_eq!(parsed["pinned"], false);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue