feat: implement capability-aware routing for OpenAI Responses API, enhancing model selection and error handling for unsupported tools

This commit is contained in:
Musa 2026-02-25 13:45:44 -08:00
parent eed196bc81
commit 2210f69db0
No known key found for this signature in database
8 changed files with 189 additions and 39 deletions

View file

@ -14,7 +14,7 @@ use hyper::{Request, Response, StatusCode};
use opentelemetry::global; use opentelemetry::global;
use opentelemetry::trace::get_active_span; use opentelemetry::trace::get_active_span;
use opentelemetry_http::HeaderInjector; use opentelemetry_http::HeaderInjector;
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument}; use tracing::{debug, info, info_span, warn, Instrument};
@ -149,6 +149,8 @@ async fn llm_chat_inner(
client_api, client_api,
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)) Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
); );
let requires_native_responses_tools =
responses_request_uses_non_function_tools(&client_request);
// If model is not specified in the request, resolve from default provider // If model is not specified in the request, resolve from default provider
let model_from_request = client_request.model().to_string(); let model_from_request = client_request.model().to_string();
@ -394,14 +396,44 @@ async fn llm_chat_inner(
// Determine final model to use // Determine final model to use
// Router returns "none" as a sentinel value when it doesn't select a specific model // Router returns "none" as a sentinel value when it doesn't select a specific model
let router_selected_model = routing_result.model_name; let router_selected_model = routing_result.model_name.clone();
let resolved_model = if router_selected_model != "none" { let resolved_model = if router_selected_model != "none" {
// Router selected a specific model via routing preferences // Router selected a specific model via routing preferences
router_selected_model router_selected_model.clone()
} else { } else {
// Router returned "none" sentinel, use validated resolved_model from request // Router returned "none" sentinel, use validated resolved_model from request
alias_resolved_model.clone() alias_resolved_model.clone()
}; };
let resolved_model = if requires_native_responses_tools {
match select_capability_compatible_model(
&llm_providers,
&resolved_model,
is_streaming_request,
)
.await
{
Some(compatible_model) => {
if compatible_model != resolved_model {
warn!(
request_id = %request_id,
selected_model = %resolved_model,
compatible_model = %compatible_model,
"selected model cannot serve responses web/file/computer tools; rerouting to compatible model"
);
}
compatible_model
}
None => {
let err_msg = "No configured model can serve OpenAI Responses API requests with non-function tools".to_string();
warn!(request_id = %request_id, error = %err_msg, "capability-aware routing failed");
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
}
} else {
resolved_model
};
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str()); tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
let span_name = if model_from_request == resolved_model { let span_name = if model_from_request == resolved_model {
@ -541,6 +573,81 @@ fn resolve_model_alias(
model_from_request.to_string() model_from_request.to_string()
} }
fn responses_request_uses_non_function_tools(client_request: &ProviderRequestType) -> bool {
match client_request {
ProviderRequestType::ResponsesAPIRequest(req) => req
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.any(|tool| !matches!(tool, ResponsesTool::Function { .. }))
})
.unwrap_or(false),
_ => false,
}
}
async fn model_supports_native_responses_api(
llm_providers: &Arc<RwLock<LlmProviders>>,
model_name: &str,
is_streaming: bool,
) -> bool {
let upstream_path = get_upstream_path(
llm_providers,
model_name,
"/v1/responses",
model_name,
is_streaming,
)
.await;
matches!(
SupportedUpstreamAPIs::from_endpoint(&upstream_path),
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
)
}
async fn select_capability_compatible_model(
llm_providers: &Arc<RwLock<LlmProviders>>,
preferred_model: &str,
is_streaming: bool,
) -> Option<String> {
if model_supports_native_responses_api(llm_providers, preferred_model, is_streaming).await {
return Some(preferred_model.to_string());
}
let (default_candidate, ordered_candidates): (Option<String>, Vec<String>) = {
let providers = llm_providers.read().await;
let default_candidate = providers.default().map(|p| p.name.clone());
let mut seen = HashSet::new();
let mut candidates = Vec::new();
for (key, provider) in providers.iter() {
if key != &provider.name || provider.internal == Some(true) {
continue;
}
if seen.insert(provider.name.clone()) {
candidates.push(provider.name.clone());
}
}
(default_candidate, candidates)
};
if let Some(default_model) = default_candidate {
if model_supports_native_responses_api(llm_providers, &default_model, is_streaming).await {
return Some(default_model);
}
}
for candidate in ordered_candidates {
if model_supports_native_responses_api(llm_providers, &candidate, is_streaming).await {
return Some(candidate);
}
}
None
}
/// Calculates the upstream path for the provider based on the model name. /// Calculates the upstream path for the provider based on the model name.
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix, /// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
/// then uses target_endpoint_for_provider to calculate the correct upstream path. /// then uses target_endpoint_for_provider to calculate the correct upstream path.

View file

@ -40,13 +40,16 @@ pub async fn router_chat_get_upstream_model(
) -> Result<RoutingResult, RoutingError> { ) -> Result<RoutingResult, RoutingError> {
// Clone metadata for routing before converting (which consumes client_request) // Clone metadata for routing before converting (which consumes client_request)
let routing_metadata = client_request.metadata().clone(); let routing_metadata = client_request.metadata().clone();
let fallback_messages = client_request.get_messages();
// Convert to ChatCompletionsRequest for routing (regardless of input type) // Convert to ChatCompletionsRequest for routing when possible.
let chat_request = match ProviderRequestType::try_from(( // If conversion fails for unsupported Responses tools (e.g. `custom`),
// route based on normalized OpenAI messages extracted from the original request.
let routing_messages = match ProviderRequestType::try_from((
client_request, client_request,
&SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions), &SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
)) { )) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req, Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req.messages,
Ok( Ok(
ProviderRequestType::MessagesRequest(_) ProviderRequestType::MessagesRequest(_)
| ProviderRequestType::BedrockConverse(_) | ProviderRequestType::BedrockConverse(_)
@ -59,6 +62,14 @@ pub async fn router_chat_get_upstream_model(
)); ));
} }
Err(err) => { Err(err) => {
let err_text = err.to_string();
if err_text.contains("Unsupported conversion") {
warn!(
"routing conversion unsupported; falling back to routing with normalized request messages: {}",
err_text
);
fallback_messages
} else {
warn!( warn!(
"failed to convert request to ChatCompletionsRequest: {}", "failed to convert request to ChatCompletionsRequest: {}",
err err
@ -68,11 +79,12 @@ pub async fn router_chat_get_upstream_model(
err err
))); )));
} }
}
}; };
debug!( debug!(
request = %serde_json::to_string(&chat_request).unwrap(), message_count = routing_messages.len(),
"router request" "router request prepared"
); );
// Extract usage preferences from metadata // Extract usage preferences from metadata
@ -87,10 +99,7 @@ pub async fn router_chat_get_upstream_model(
.and_then(|s| serde_yaml::from_str(s).ok()); .and_then(|s| serde_yaml::from_str(s).ok());
// Prepare log message with latest message from chat request // Prepare log message with latest message from chat request
let latest_message_for_log = chat_request let latest_message_for_log = routing_messages.last().map_or("None".to_string(), |msg| {
.messages
.last()
.map_or("None".to_string(), |msg| {
msg.content msg.content
.as_ref() .as_ref()
.map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n")) .map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n"))
@ -120,7 +129,7 @@ pub async fn router_chat_get_upstream_model(
// Attempt to determine route using the router service // Attempt to determine route using the router service
let routing_result = router_service let routing_result = router_service
.determine_route( .determine_route(
&chat_request.messages, &routing_messages,
traceparent, traceparent,
usage_preferences, usage_preferences,
request_id, request_id,

View file

@ -209,10 +209,13 @@ pub struct AudioConfig {
} }
/// Text configuration /// Text configuration
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextConfig { pub struct TextConfig {
/// Text format configuration /// Text format configuration
pub format: TextFormat, pub format: Option<TextFormat>,
/// Controls response verbosity for models that support it.
pub verbosity: Option<String>,
} }
/// Text format /// Text format
@ -302,6 +305,12 @@ pub enum Tool {
display_height_px: Option<i32>, display_height_px: Option<i32>,
display_number: Option<i32>, display_number: Option<i32>,
}, },
/// Custom tool (forward-compatible passthrough for provider-specific tools)
#[serde(rename = "custom")]
Custom {
#[serde(flatten)]
config: serde_json::Value,
},
} }
/// Ranking options for file search /// Ranking options for file search

View file

@ -236,7 +236,8 @@ impl ResponsesAPIStreamBuffer {
}), }),
store: Some(true), store: Some(true),
text: Some(TextConfig { text: Some(TextConfig {
format: TextFormat::Text, format: Some(TextFormat::Text),
verbosity: None,
}), }),
audio: None, audio: None,
modalities: None, modalities: None,

View file

@ -466,6 +466,9 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
ResponsesTool::Computer { .. } => Err(TransformError::UnsupportedConversion( ResponsesTool::Computer { .. } => Err(TransformError::UnsupportedConversion(
"Computer tool is not supported in ChatCompletions API. Only function tools are supported.".to_string() "Computer tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
)), )),
ResponsesTool::Custom { .. } => Err(TransformError::UnsupportedConversion(
"Custom tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
)),
} }
}).collect::<Result<Vec<_>, _>>() }).collect::<Result<Vec<_>, _>>()
}).transpose()?, }).transpose()?,

View file

@ -1139,6 +1139,27 @@ impl HttpContext for StreamContext {
let current_time = get_current_time().unwrap(); let current_time = get_current_time().unwrap();
if end_of_stream && body_size == 0 { if end_of_stream && body_size == 0 {
// Flush any buffered partial SSE event on stream end.
// This handles cases where the last logical SSE event (for example, response.completed)
// was split across chunks and the tail only arrives at end-of-stream.
if self.streaming_response {
let provider_id = self.get_provider_id();
match self.handle_streaming_response(&[], provider_id) {
Ok(serialized_body) => {
if !serialized_body.is_empty() {
self.set_http_response_body(0, 0, &serialized_body);
debug!(
"request_id={}: flushed buffered streaming bytes on end_of_stream, size={}",
self.request_identifier(),
serialized_body.len()
);
}
}
Err(_) => {
// Ignore flush errors and proceed with end-of-request handling.
}
}
}
debug!( debug!(
"request_id={}: response body complete, total_bytes={}", "request_id={}: response body complete, total_bytes={}",
self.request_identifier(), self.request_identifier(),

View file

@ -7,22 +7,22 @@ listeners:
model_providers: model_providers:
# OpenAI Models # OpenAI Models
- model: openai/gpt-5-2025-08-07 - model: openai/gpt-5.3-codex
default: true default: true
access_key: $OPENAI_API_KEY access_key: $OPENAI_API_KEY
routing_preferences: routing_preferences:
- name: code generation - name: code generation
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
- model: openai/gpt-4.1-2025-04-14 - model: openai/gpt-5-2025-08-07
access_key: $OPENAI_API_KEY access_key: $OPENAI_API_KEY
routing_preferences: routing_preferences:
- name: code understanding - name: code understanding
description: understand and explain existing code snippets, functions, or libraries description: understand and explain existing code snippets, functions, or libraries
# Anthropic Model # # Anthropic Model
- model: anthropic/claude-sonnet-4-6 # - model: anthropic/claude-sonnet-4-6
access_key: $ANTHROPIC_API_KEY # access_key: $ANTHROPIC_API_KEY
# Ollama Model (optional local fallback) # Ollama Model (optional local fallback)
- model: ollama/llama3.1 - model: ollama/llama3.1
@ -32,7 +32,7 @@ model_providers:
model_aliases: model_aliases:
# Default model Codex should request when launched by planoai cli-agent codex # Default model Codex should request when launched by planoai cli-agent codex
arch.codex.default: arch.codex.default:
target: gpt-5-2025-08-07 target: gpt-5.3-codex
tracing: tracing:
random_sampling: 100 random_sampling: 100

View file

@ -7,22 +7,22 @@ listeners:
model_providers: model_providers:
# OpenAI Models # OpenAI Models
- model: openai/gpt-5-2025-08-07 - model: openai/gpt-5.3-codex
default: true default: true
access_key: $OPENAI_API_KEY access_key: $OPENAI_API_KEY
routing_preferences: routing_preferences:
- name: code generation - name: code generation
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
- model: openai/gpt-4.1-2025-04-14 - model: openai/gpt-5-2025-08-07
access_key: $OPENAI_API_KEY access_key: $OPENAI_API_KEY
routing_preferences: routing_preferences:
- name: code understanding - name: code understanding
description: understand and explain existing code snippets, functions, or libraries description: understand and explain existing code snippets, functions, or libraries
# Anthropic Model # # Anthropic Model
- model: anthropic/claude-sonnet-4-6 # - model: anthropic/claude-sonnet-4-6
access_key: $ANTHROPIC_API_KEY # access_key: $ANTHROPIC_API_KEY
# Ollama Model (optional local fallback) # Ollama Model (optional local fallback)
- model: ollama/llama3.1 - model: ollama/llama3.1
@ -32,7 +32,7 @@ model_providers:
model_aliases: model_aliases:
# Default model OpenCode should request when launched by planoai cli-agent opencode # Default model OpenCode should request when launched by planoai cli-agent opencode
arch.opencode.default: arch.opencode.default:
target: gpt-5-2025-08-07 target: gpt-5.3-codex
tracing: tracing:
random_sampling: 100 random_sampling: 100