mirror of
https://github.com/katanemo/plano.git
synced 2026-05-24 14:05:14 +02:00
feat: implement capability-aware routing for OpenAI Responses API, enhancing model selection and error handling for unsupported tools
This commit is contained in:
parent
eed196bc81
commit
2210f69db0
8 changed files with 189 additions and 39 deletions
|
|
@ -14,7 +14,7 @@ use hyper::{Request, Response, StatusCode};
|
|||
use opentelemetry::global;
|
||||
use opentelemetry::trace::get_active_span;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
|
@ -149,6 +149,8 @@ async fn llm_chat_inner(
|
|||
client_api,
|
||||
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
|
||||
);
|
||||
let requires_native_responses_tools =
|
||||
responses_request_uses_non_function_tools(&client_request);
|
||||
|
||||
// If model is not specified in the request, resolve from default provider
|
||||
let model_from_request = client_request.model().to_string();
|
||||
|
|
@ -394,14 +396,44 @@ async fn llm_chat_inner(
|
|||
|
||||
// Determine final model to use
|
||||
// Router returns "none" as a sentinel value when it doesn't select a specific model
|
||||
let router_selected_model = routing_result.model_name;
|
||||
let router_selected_model = routing_result.model_name.clone();
|
||||
let resolved_model = if router_selected_model != "none" {
|
||||
// Router selected a specific model via routing preferences
|
||||
router_selected_model
|
||||
router_selected_model.clone()
|
||||
} else {
|
||||
// Router returned "none" sentinel, use validated resolved_model from request
|
||||
alias_resolved_model.clone()
|
||||
};
|
||||
let resolved_model = if requires_native_responses_tools {
|
||||
match select_capability_compatible_model(
|
||||
&llm_providers,
|
||||
&resolved_model,
|
||||
is_streaming_request,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Some(compatible_model) => {
|
||||
if compatible_model != resolved_model {
|
||||
warn!(
|
||||
request_id = %request_id,
|
||||
selected_model = %resolved_model,
|
||||
compatible_model = %compatible_model,
|
||||
"selected model cannot serve responses web/file/computer tools; rerouting to compatible model"
|
||||
);
|
||||
}
|
||||
compatible_model
|
||||
}
|
||||
None => {
|
||||
let err_msg = "No configured model can serve OpenAI Responses API requests with non-function tools".to_string();
|
||||
warn!(request_id = %request_id, error = %err_msg, "capability-aware routing failed");
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
resolved_model
|
||||
};
|
||||
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
|
||||
|
||||
let span_name = if model_from_request == resolved_model {
|
||||
|
|
@ -541,6 +573,81 @@ fn resolve_model_alias(
|
|||
model_from_request.to_string()
|
||||
}
|
||||
|
||||
fn responses_request_uses_non_function_tools(client_request: &ProviderRequestType) -> bool {
|
||||
match client_request {
|
||||
ProviderRequestType::ResponsesAPIRequest(req) => req
|
||||
.tools
|
||||
.as_ref()
|
||||
.map(|tools| {
|
||||
tools
|
||||
.iter()
|
||||
.any(|tool| !matches!(tool, ResponsesTool::Function { .. }))
|
||||
})
|
||||
.unwrap_or(false),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
async fn model_supports_native_responses_api(
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
is_streaming: bool,
|
||||
) -> bool {
|
||||
let upstream_path = get_upstream_path(
|
||||
llm_providers,
|
||||
model_name,
|
||||
"/v1/responses",
|
||||
model_name,
|
||||
is_streaming,
|
||||
)
|
||||
.await;
|
||||
matches!(
|
||||
SupportedUpstreamAPIs::from_endpoint(&upstream_path),
|
||||
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
|
||||
)
|
||||
}
|
||||
|
||||
async fn select_capability_compatible_model(
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
preferred_model: &str,
|
||||
is_streaming: bool,
|
||||
) -> Option<String> {
|
||||
if model_supports_native_responses_api(llm_providers, preferred_model, is_streaming).await {
|
||||
return Some(preferred_model.to_string());
|
||||
}
|
||||
|
||||
let (default_candidate, ordered_candidates): (Option<String>, Vec<String>) = {
|
||||
let providers = llm_providers.read().await;
|
||||
let default_candidate = providers.default().map(|p| p.name.clone());
|
||||
|
||||
let mut seen = HashSet::new();
|
||||
let mut candidates = Vec::new();
|
||||
for (key, provider) in providers.iter() {
|
||||
if key != &provider.name || provider.internal == Some(true) {
|
||||
continue;
|
||||
}
|
||||
if seen.insert(provider.name.clone()) {
|
||||
candidates.push(provider.name.clone());
|
||||
}
|
||||
}
|
||||
(default_candidate, candidates)
|
||||
};
|
||||
|
||||
if let Some(default_model) = default_candidate {
|
||||
if model_supports_native_responses_api(llm_providers, &default_model, is_streaming).await {
|
||||
return Some(default_model);
|
||||
}
|
||||
}
|
||||
|
||||
for candidate in ordered_candidates {
|
||||
if model_supports_native_responses_api(llm_providers, &candidate, is_streaming).await {
|
||||
return Some(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Calculates the upstream path for the provider based on the model name.
|
||||
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
|
||||
/// then uses target_endpoint_for_provider to calculate the correct upstream path.
|
||||
|
|
|
|||
|
|
@ -40,13 +40,16 @@ pub async fn router_chat_get_upstream_model(
|
|||
) -> Result<RoutingResult, RoutingError> {
|
||||
// Clone metadata for routing before converting (which consumes client_request)
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
let fallback_messages = client_request.get_messages();
|
||||
|
||||
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
||||
let chat_request = match ProviderRequestType::try_from((
|
||||
// Convert to ChatCompletionsRequest for routing when possible.
|
||||
// If conversion fails for unsupported Responses tools (e.g. `custom`),
|
||||
// route based on normalized OpenAI messages extracted from the original request.
|
||||
let routing_messages = match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req.messages,
|
||||
Ok(
|
||||
ProviderRequestType::MessagesRequest(_)
|
||||
| ProviderRequestType::BedrockConverse(_)
|
||||
|
|
@ -59,20 +62,29 @@ pub async fn router_chat_get_upstream_model(
|
|||
));
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to convert request to ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
return Err(RoutingError::internal_error(format!(
|
||||
"Failed to convert request: {}",
|
||||
err
|
||||
)));
|
||||
let err_text = err.to_string();
|
||||
if err_text.contains("Unsupported conversion") {
|
||||
warn!(
|
||||
"routing conversion unsupported; falling back to routing with normalized request messages: {}",
|
||||
err_text
|
||||
);
|
||||
fallback_messages
|
||||
} else {
|
||||
warn!(
|
||||
"failed to convert request to ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
return Err(RoutingError::internal_error(format!(
|
||||
"Failed to convert request: {}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
request = %serde_json::to_string(&chat_request).unwrap(),
|
||||
"router request"
|
||||
message_count = routing_messages.len(),
|
||||
"router request prepared"
|
||||
);
|
||||
|
||||
// Extract usage preferences from metadata
|
||||
|
|
@ -87,14 +99,11 @@ pub async fn router_chat_get_upstream_model(
|
|||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
// Prepare log message with latest message from chat request
|
||||
let latest_message_for_log = chat_request
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content
|
||||
.as_ref()
|
||||
.map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n"))
|
||||
});
|
||||
let latest_message_for_log = routing_messages.last().map_or("None".to_string(), |msg| {
|
||||
msg.content
|
||||
.as_ref()
|
||||
.map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n"))
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
|
||||
|
|
@ -120,7 +129,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
// Attempt to determine route using the router service
|
||||
let routing_result = router_service
|
||||
.determine_route(
|
||||
&chat_request.messages,
|
||||
&routing_messages,
|
||||
traceparent,
|
||||
usage_preferences,
|
||||
request_id,
|
||||
|
|
|
|||
|
|
@ -209,10 +209,13 @@ pub struct AudioConfig {
|
|||
}
|
||||
|
||||
/// Text configuration
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TextConfig {
|
||||
/// Text format configuration
|
||||
pub format: TextFormat,
|
||||
pub format: Option<TextFormat>,
|
||||
/// Controls response verbosity for models that support it.
|
||||
pub verbosity: Option<String>,
|
||||
}
|
||||
|
||||
/// Text format
|
||||
|
|
@ -302,6 +305,12 @@ pub enum Tool {
|
|||
display_height_px: Option<i32>,
|
||||
display_number: Option<i32>,
|
||||
},
|
||||
/// Custom tool (forward-compatible passthrough for provider-specific tools)
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
#[serde(flatten)]
|
||||
config: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
/// Ranking options for file search
|
||||
|
|
|
|||
|
|
@ -236,7 +236,8 @@ impl ResponsesAPIStreamBuffer {
|
|||
}),
|
||||
store: Some(true),
|
||||
text: Some(TextConfig {
|
||||
format: TextFormat::Text,
|
||||
format: Some(TextFormat::Text),
|
||||
verbosity: None,
|
||||
}),
|
||||
audio: None,
|
||||
modalities: None,
|
||||
|
|
|
|||
|
|
@ -466,6 +466,9 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
|
|||
ResponsesTool::Computer { .. } => Err(TransformError::UnsupportedConversion(
|
||||
"Computer tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
ResponsesTool::Custom { .. } => Err(TransformError::UnsupportedConversion(
|
||||
"Custom tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
}
|
||||
}).collect::<Result<Vec<_>, _>>()
|
||||
}).transpose()?,
|
||||
|
|
|
|||
|
|
@ -1139,6 +1139,27 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let current_time = get_current_time().unwrap();
|
||||
if end_of_stream && body_size == 0 {
|
||||
// Flush any buffered partial SSE event on stream end.
|
||||
// This handles cases where the last logical SSE event (for example, response.completed)
|
||||
// was split across chunks and the tail only arrives at end-of-stream.
|
||||
if self.streaming_response {
|
||||
let provider_id = self.get_provider_id();
|
||||
match self.handle_streaming_response(&[], provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
if !serialized_body.is_empty() {
|
||||
self.set_http_response_body(0, 0, &serialized_body);
|
||||
debug!(
|
||||
"request_id={}: flushed buffered streaming bytes on end_of_stream, size={}",
|
||||
self.request_identifier(),
|
||||
serialized_body.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Ignore flush errors and proceed with end-of-request handling.
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!(
|
||||
"request_id={}: response body complete, total_bytes={}",
|
||||
self.request_identifier(),
|
||||
|
|
|
|||
|
|
@ -7,22 +7,22 @@ listeners:
|
|||
|
||||
model_providers:
|
||||
# OpenAI Models
|
||||
- model: openai/gpt-5-2025-08-07
|
||||
- model: openai/gpt-5.3-codex
|
||||
default: true
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
|
||||
- model: openai/gpt-4.1-2025-04-14
|
||||
- model: openai/gpt-5-2025-08-07
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code understanding
|
||||
description: understand and explain existing code snippets, functions, or libraries
|
||||
|
||||
# Anthropic Model
|
||||
- model: anthropic/claude-sonnet-4-6
|
||||
access_key: $ANTHROPIC_API_KEY
|
||||
# # Anthropic Model
|
||||
# - model: anthropic/claude-sonnet-4-6
|
||||
# access_key: $ANTHROPIC_API_KEY
|
||||
|
||||
# Ollama Model (optional local fallback)
|
||||
- model: ollama/llama3.1
|
||||
|
|
@ -32,7 +32,7 @@ model_providers:
|
|||
model_aliases:
|
||||
# Default model Codex should request when launched by planoai cli-agent codex
|
||||
arch.codex.default:
|
||||
target: gpt-5-2025-08-07
|
||||
target: gpt-5.3-codex
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
|
|
|
|||
|
|
@ -7,22 +7,22 @@ listeners:
|
|||
|
||||
model_providers:
|
||||
# OpenAI Models
|
||||
- model: openai/gpt-5-2025-08-07
|
||||
- model: openai/gpt-5.3-codex
|
||||
default: true
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
|
||||
- model: openai/gpt-4.1-2025-04-14
|
||||
- model: openai/gpt-5-2025-08-07
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code understanding
|
||||
description: understand and explain existing code snippets, functions, or libraries
|
||||
|
||||
# Anthropic Model
|
||||
- model: anthropic/claude-sonnet-4-6
|
||||
access_key: $ANTHROPIC_API_KEY
|
||||
# # Anthropic Model
|
||||
# - model: anthropic/claude-sonnet-4-6
|
||||
# access_key: $ANTHROPIC_API_KEY
|
||||
|
||||
# Ollama Model (optional local fallback)
|
||||
- model: ollama/llama3.1
|
||||
|
|
@ -32,7 +32,7 @@ model_providers:
|
|||
model_aliases:
|
||||
# Default model OpenCode should request when launched by planoai cli-agent opencode
|
||||
arch.opencode.default:
|
||||
target: gpt-5-2025-08-07
|
||||
target: gpt-5.3-codex
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue