mirror of
https://github.com/katanemo/plano.git
synced 2026-05-24 14:05:14 +02:00
merge main
This commit is contained in:
commit
692499d910
22 changed files with 1771 additions and 215 deletions
|
|
@ -135,13 +135,27 @@ async fn llm_chat_inner(
|
|||
}
|
||||
};
|
||||
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
let raw_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
body = %String::from_utf8_lossy(&chat_request_bytes),
|
||||
body = %String::from_utf8_lossy(&raw_bytes),
|
||||
"request body received"
|
||||
);
|
||||
|
||||
// Extract routing_policy from request body if present
|
||||
let (chat_request_bytes, inline_routing_policy) =
|
||||
match crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false) {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
|
|
@ -193,6 +207,7 @@ async fn llm_chat_inner(
|
|||
let temperature = client_request.get_temperature();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let alias_resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||
let (provider_id, _) = get_provider_info(&llm_providers, &alias_resolved_model).await;
|
||||
|
||||
// Validate that the requested model exists in configuration
|
||||
// This matches the validation in llm_gateway routing.rs
|
||||
|
|
@ -330,6 +345,11 @@ async fn llm_chat_inner(
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(ref client_api_kind) = client_api {
|
||||
let upstream_api =
|
||||
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
|
||||
client_request.normalize_for_upstream(provider_id, &upstream_api);
|
||||
}
|
||||
// === v1/responses state management: Determine upstream API and combine input if needed ===
|
||||
// Do this BEFORE routing since routing consumes the request
|
||||
// Only process state if state_storage is configured
|
||||
|
|
@ -429,6 +449,7 @@ async fn llm_chat_inner(
|
|||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_policy,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -575,7 +596,6 @@ async fn llm_chat_inner(
|
|||
.into_response()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves model aliases by looking up the requested model in the model_aliases map.
|
||||
/// Returns the target model if an alias is found, otherwise returns the original model.
|
||||
fn resolve_model_alias(
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
traceparent: &str,
|
||||
request_path: &str,
|
||||
request_id: &str,
|
||||
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
) -> Result<RoutingResult, RoutingError> {
|
||||
// Clone metadata for routing before converting (which consumes client_request)
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
|
@ -76,16 +77,21 @@ pub async fn router_chat_get_upstream_model(
|
|||
"router request"
|
||||
);
|
||||
|
||||
// Extract usage preferences from metadata
|
||||
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("plano_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
// Use inline preferences if provided, otherwise fall back to metadata extraction
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
|
||||
{
|
||||
inline_usage_preferences
|
||||
} else {
|
||||
let usage_preferences_str: Option<String> =
|
||||
routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("plano_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok())
|
||||
};
|
||||
|
||||
// Prepare log message with latest message from chat request
|
||||
let latest_message_for_log = chat_request
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::SpanAttributes;
|
||||
use common::configuration::{ModelUsagePreference, SpanAttributes};
|
||||
use common::consts::{REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
|
||||
use common::errors::BrightStaffError;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
|
|
@ -14,6 +14,53 @@ use crate::handlers::router_chat::router_chat_get_upstream_model;
|
|||
use crate::router::llm_router::RouterService;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
/// Serialized size, in bytes, above which an inline `routing_policy` triggers a warning log.
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
|
||||
|
||||
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
|
||||
/// and parsed preferences. The `routing_policy` field is removed from the JSON
|
||||
/// before re-serializing so downstream parsers don't see the non-standard field.
|
||||
///
|
||||
/// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB.
|
||||
pub fn extract_routing_policy(
|
||||
raw_bytes: &[u8],
|
||||
warn_on_size: bool,
|
||||
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
|
||||
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
|
||||
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
|
||||
|
||||
let preferences = json_body
|
||||
.as_object_mut()
|
||||
.and_then(|obj| obj.remove("routing_policy"))
|
||||
.and_then(|policy_value| {
|
||||
if warn_on_size {
|
||||
let policy_str = serde_json::to_string(&policy_value).unwrap_or_default();
|
||||
if policy_str.len() > ROUTING_POLICY_SIZE_WARNING_BYTES {
|
||||
warn!(
|
||||
size_bytes = policy_str.len(),
|
||||
limit_bytes = ROUTING_POLICY_SIZE_WARNING_BYTES,
|
||||
"routing_policy exceeds recommended size limit"
|
||||
);
|
||||
}
|
||||
}
|
||||
match serde_json::from_value::<Vec<ModelUsagePreference>>(policy_value) {
|
||||
Ok(prefs) => {
|
||||
info!(
|
||||
num_models = prefs.len(),
|
||||
"using inline routing_policy from request body"
|
||||
);
|
||||
Some(prefs)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse routing_policy");
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
|
||||
Ok((bytes, preferences))
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct RoutingDecisionResponse {
|
||||
model: String,
|
||||
|
|
@ -98,13 +145,26 @@ async fn routing_decision_inner(
|
|||
.to_string();
|
||||
|
||||
// Parse request body
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
let raw_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
body = %String::from_utf8_lossy(&chat_request_bytes),
|
||||
body = %String::from_utf8_lossy(&raw_bytes),
|
||||
"routing decision request body received"
|
||||
);
|
||||
|
||||
// Extract routing_policy from request body before parsing as ProviderRequestType
|
||||
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request JSON: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
let client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
|
|
@ -120,13 +180,14 @@ async fn routing_decision_inner(
|
|||
}
|
||||
};
|
||||
|
||||
// Call the existing routing logic
|
||||
// Call the existing routing logic with inline preferences
|
||||
let routing_result = router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_preferences,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -161,3 +222,136 @@ async fn routing_decision_inner(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal chat-completions body, optionally appending raw
    /// `"key": value` pairs (`extra_fields`) after the `messages` field.
    fn make_chat_body(extra_fields: &str) -> Vec<u8> {
        let extra = if extra_fields.is_empty() {
            String::new()
        } else {
            format!(", {}", extra_fields)
        };
        format!(
            r#"{{"model": "gpt-4o-mini", "messages": [{{"role": "user", "content": "hello"}}]{}}}"#,
            extra
        )
        .into_bytes()
    }

    /// A body without `routing_policy` yields no preferences and stays parseable.
    #[test]
    fn extract_routing_policy_no_policy() {
        let body = make_chat_body("");
        let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();

        assert!(prefs.is_none());
        let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
        assert_eq!(cleaned_json["model"], "gpt-4o-mini");
        assert!(cleaned_json.get("routing_policy").is_none());
    }

    /// A well-formed policy is parsed in order and removed from the body.
    #[test]
    fn extract_routing_policy_valid_policy() {
        let policy = r#""routing_policy": [
            {
                "model": "openai/gpt-4o",
                "routing_preferences": [
                    {"name": "coding", "description": "code generation tasks"}
                ]
            },
            {
                "model": "openai/gpt-4o-mini",
                "routing_preferences": [
                    {"name": "general", "description": "general questions"}
                ]
            }
        ]"#;
        let body = make_chat_body(policy);
        let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();

        let prefs = prefs.expect("should have parsed preferences");
        assert_eq!(prefs.len(), 2);
        assert_eq!(prefs[0].model, "openai/gpt-4o");
        assert_eq!(prefs[0].routing_preferences[0].name, "coding");
        assert_eq!(prefs[1].model, "openai/gpt-4o-mini");
        assert_eq!(prefs[1].routing_preferences[0].name, "general");

        // routing_policy should be stripped from cleaned body
        let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
        assert!(cleaned_json.get("routing_policy").is_none());
        assert_eq!(cleaned_json["model"], "gpt-4o-mini");
    }

    /// A policy with the wrong shape is dropped silently rather than failing the request.
    #[test]
    fn extract_routing_policy_invalid_policy_returns_none() {
        // routing_policy is present but has wrong shape
        let policy = r#""routing_policy": "not-an-array""#;
        let body = make_chat_body(policy);
        let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();

        // Invalid policy should be ignored (returns None), not error
        assert!(prefs.is_none());
        // routing_policy should still be stripped from cleaned body
        let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
        assert!(cleaned_json.get("routing_policy").is_none());
    }

    /// Input that is not JSON at all is reported as an error.
    #[test]
    fn extract_routing_policy_invalid_json_returns_error() {
        let body = b"not valid json";
        let result = extract_routing_policy(body, false);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Failed to parse JSON"));
    }

    /// An empty policy array is valid and distinct from "no policy".
    #[test]
    fn extract_routing_policy_empty_array() {
        let policy = r#""routing_policy": []"#;
        let body = make_chat_body(policy);
        let (_, prefs) = extract_routing_policy(&body, false).unwrap();

        let prefs = prefs.expect("empty array is valid");
        assert_eq!(prefs.len(), 0);
    }

    /// Stripping the policy must not disturb sibling fields.
    #[test]
    fn extract_routing_policy_preserves_other_fields() {
        let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
        let body = make_chat_body(policy);
        let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();

        assert!(prefs.is_some());
        let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
        assert_eq!(cleaned_json["temperature"], 0.5);
        assert_eq!(cleaned_json["max_tokens"], 100);
        assert!(cleaned_json.get("routing_policy").is_none());
    }

    /// With `warn_on_size` enabled, an oversized policy only logs: it is
    /// still parsed and still stripped from the body.
    #[test]
    fn extract_routing_policy_oversized_policy_still_parsed() {
        let long_description = "x".repeat(ROUTING_POLICY_SIZE_WARNING_BYTES + 1);
        let policy = format!(
            r#""routing_policy": [{{"model": "gpt-4o", "routing_preferences": [{{"name": "bulk", "description": "{}"}}]}}]"#,
            long_description
        );
        let body = make_chat_body(&policy);
        let (cleaned, prefs) = extract_routing_policy(&body, true).unwrap();

        let prefs = prefs.expect("oversized policy should still be parsed");
        assert_eq!(prefs.len(), 1);
        assert_eq!(prefs[0].model, "gpt-4o");
        let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
        assert!(cleaned_json.get("routing_policy").is_none());
        assert_eq!(cleaned_json["model"], "gpt-4o-mini");
    }

    /// Valid JSON whose top level is not an object has no policy to extract
    /// and must round-trip without error.
    #[test]
    fn extract_routing_policy_non_object_body() {
        let body = b"[1, 2, 3]";
        let (cleaned, prefs) = extract_routing_policy(body, false).unwrap();

        assert!(prefs.is_none());
        let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
        assert_eq!(cleaned_json.as_array().map(|items| items.len()), Some(3));
    }

    /// All fields of a routing decision appear in the serialized response.
    #[test]
    fn routing_decision_response_serialization() {
        let response = RoutingDecisionResponse {
            model: "openai/gpt-4o".to_string(),
            route: Some("code_generation".to_string()),
            trace_id: "abc123".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["model"], "openai/gpt-4o");
        assert_eq!(parsed["route"], "code_generation");
        assert_eq!(parsed["trace_id"], "abc123");
    }

    /// A missing route reads back as JSON null.
    #[test]
    fn routing_decision_response_serialization_no_route() {
        let response = RoutingDecisionResponse {
            model: "none".to_string(),
            route: None,
            trace_id: "abc123".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["model"], "none");
        assert!(parsed["route"].is_null());
    }
}
|
||||
|
|
|
|||
|
|
@ -130,6 +130,7 @@ pub fn extract_input_items(input: &InputParam) -> Vec<InputItem> {
|
|||
}]),
|
||||
})]
|
||||
}
|
||||
InputParam::SingleItem(item) => vec![item.clone()],
|
||||
InputParam::Items(items) => items.clone(),
|
||||
}
|
||||
}
|
||||
|
|
@ -146,3 +147,101 @@ pub async fn retrieve_and_combine_input(
|
|||
let combined_input = storage.merge(&prev_state, current_input);
|
||||
Ok(combined_input)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::extract_input_items;
    use hermesllm::apis::openai_responses::{
        InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole,
    };

    /// Unwraps `item` as a message, panicking with `failure` for any other variant.
    fn as_message<'a>(item: &'a InputItem, failure: &str) -> &'a InputMessage {
        match item {
            InputItem::Message(message) => message,
            _ => panic!("{}", failure),
        }
    }

    /// Returns the content items of `message`, panicking if its content is not an item list.
    fn content_items(message: &InputMessage) -> &[InputContent] {
        match &message.content {
            MessageContent::Items(items) => items.as_slice(),
            _ => panic!("expected MessageContent::Items"),
        }
    }

    /// Returns the text of `content`, panicking if it is not an `InputText` entry.
    fn text_of(content: &InputContent) -> &str {
        match content {
            InputContent::InputText { text } => text.as_str(),
            _ => panic!("expected InputContent::InputText"),
        }
    }

    /// Plain text input becomes exactly one user message holding one text item.
    #[test]
    fn test_extract_input_items_converts_text_to_user_message_item() {
        let input = InputParam::Text("hello world".to_string());
        let result = extract_input_items(&input);
        assert_eq!(result.len(), 1);

        let message = as_message(&result[0], "expected InputItem::Message");
        assert!(matches!(message.role, MessageRole::User));

        let parts = content_items(message);
        assert_eq!(parts.len(), 1);

        assert_eq!(text_of(&parts[0]), "hello world");
    }

    /// A single input item is passed through with its role and text intact.
    #[test]
    fn test_extract_input_items_preserves_single_item() {
        let note = InputContent::InputText {
            text: "assistant note".to_string(),
        };
        let item = InputItem::Message(InputMessage {
            role: MessageRole::Assistant,
            content: MessageContent::Items(vec![note]),
        });

        let result = extract_input_items(&InputParam::SingleItem(item.clone()));
        assert_eq!(result.len(), 1);

        let message = as_message(&result[0], "expected InputItem::Message");
        assert!(matches!(message.role, MessageRole::Assistant));

        let parts = content_items(message);
        assert_eq!(text_of(&parts[0]), "assistant note");
    }

    /// A list of input items keeps its length, order, roles and text.
    #[test]
    fn test_extract_input_items_preserves_items_list() {
        let user_turn = InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "first".to_string(),
            }]),
        });
        let assistant_turn = InputItem::Message(InputMessage {
            role: MessageRole::Assistant,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "second".to_string(),
            }]),
        });
        let items = vec![user_turn, assistant_turn];

        let result = extract_input_items(&InputParam::Items(items.clone()));
        assert_eq!(result.len(), items.len());

        let first = as_message(&result[0], "expected first item to be message");
        assert!(matches!(first.role, MessageRole::User));
        let first_parts = content_items(first);
        assert_eq!(text_of(&first_parts[0]), "first");

        let second = as_message(&result[1], "expected second item to be message");
        assert!(matches!(second.role, MessageRole::Assistant));
        let second_parts = content_items(second);
        assert_eq!(text_of(&second_parts[0]), "second");
    }
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue