merge main and resolve conflicts

This commit is contained in:
Adil Hafeez 2026-03-11 18:57:36 +00:00
commit bcb7f60005
26 changed files with 2145 additions and 213 deletions

View file

@ -18,7 +18,7 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
mod router;
pub(crate) mod router;
use crate::app_state::AppState;
use crate::handlers::request::extract_request_id;
@ -120,6 +120,7 @@ async fn llm_chat_inner(
temperature,
tool_names,
user_message_preview,
inline_routing_policy,
} = parsed;
// Record LLM-specific span attributes
@ -186,6 +187,7 @@ async fn llm_chat_inner(
&traceparent,
&request_path,
&request_id,
inline_routing_policy,
)
.await
}
@ -245,6 +247,7 @@ struct PreparedRequest {
temperature: Option<f32>,
tool_names: Option<Vec<String>>,
user_message_preview: Option<String>,
inline_routing_policy: Option<Vec<common::configuration::ModelUsagePreference>>,
}
/// Parse the body, resolve the model alias, and validate the model exists.
@ -256,7 +259,7 @@ async fn parse_and_validate_request(
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: &Arc<RwLock<LlmProviders>>,
) -> Result<PreparedRequest, Response<BoxBody<Bytes, hyper::Error>>> {
let chat_request_bytes = request
let raw_bytes = request
.collect()
.await
.map_err(|_| {
@ -267,10 +270,21 @@ async fn parse_and_validate_request(
.to_bytes();
debug!(
body = %String::from_utf8_lossy(&chat_request_bytes),
body = %String::from_utf8_lossy(&raw_bytes),
"request body received"
);
// Extract routing_policy from request body if present
let (chat_request_bytes, inline_routing_policy) =
crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err(
|err| {
warn!(error = %err, "failed to parse request JSON");
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
*r.status_mut() = StatusCode::BAD_REQUEST;
r
},
)?;
let api_type = SupportedAPIsFromClient::from_endpoint(request_path).ok_or_else(|| {
warn!(path = %request_path, "unsupported endpoint");
let mut r = Response::new(full(format!("Unsupported endpoint: {}", request_path)));
@ -296,6 +310,7 @@ async fn parse_and_validate_request(
let temperature = client_request.get_temperature();
let is_streaming_request = client_request.is_streaming();
let alias_resolved_model = resolve_model_alias(&model_from_request, model_aliases);
let (provider_id, _) = get_provider_info(llm_providers, &alias_resolved_model).await;
// Validate model exists in configuration
if llm_providers
@ -332,6 +347,14 @@ async fn parse_and_validate_request(
if client_request.remove_metadata_key("archgw_preference_config") {
debug!("removed archgw_preference_config from metadata");
}
if client_request.remove_metadata_key("plano_preference_config") {
debug!("removed plano_preference_config from metadata");
}
if let Some(ref client_api_kind) = client_api {
let upstream_api =
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
client_request.normalize_for_upstream(provider_id, &upstream_api);
}
Ok(PreparedRequest {
client_request,
@ -344,6 +367,7 @@ async fn parse_and_validate_request(
temperature,
tool_names,
user_message_preview,
inline_routing_policy,
})
}

View file

@ -10,6 +10,7 @@ use crate::tracing::routing;
pub struct RoutingResult {
pub model_name: String,
pub route_name: Option<String>,
}
pub struct RoutingError {
@ -37,6 +38,7 @@ pub async fn router_chat_get_upstream_model(
traceparent: &str,
request_path: &str,
request_id: &str,
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
) -> Result<RoutingResult, RoutingError> {
// Clone metadata for routing before converting (which consumes client_request)
let routing_metadata = client_request.metadata().clone();
@ -75,16 +77,21 @@ pub async fn router_chat_get_upstream_model(
"router request"
);
// Extract usage preferences from metadata
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("plano_preference_config")
.map(|value| value.to_string())
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
// Use inline preferences if provided, otherwise fall back to metadata extraction
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
{
inline_usage_preferences
} else {
let usage_preferences_str: Option<String> =
routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("plano_preference_config")
.map(|value| value.to_string())
});
usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok())
};
// Prepare log message with latest message from chat request
let latest_message_for_log = chat_request
@ -133,9 +140,12 @@ pub async fn router_chat_get_upstream_model(
match routing_result {
Ok(route) => match route {
Some((_, model_name)) => {
Some((route_name, model_name)) => {
current_span.record("route.selected_model", model_name.as_str());
Ok(RoutingResult { model_name })
Ok(RoutingResult {
model_name,
route_name: Some(route_name),
})
}
None => {
// No route determined, return sentinel value "none"
@ -145,6 +155,7 @@ pub async fn router_chat_get_upstream_model(
Ok(RoutingResult {
model_name: "none".to_string(),
route_name: None,
})
}
},

View file

@ -5,6 +5,7 @@ pub mod llm;
pub mod models;
pub mod request;
pub mod response;
pub mod routing_service;
pub mod utils;
#[cfg(test)]

View file

@ -0,0 +1,357 @@
use bytes::Bytes;
use common::configuration::{ModelUsagePreference, SpanAttributes};
use common::consts::{REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
use common::errors::BrightStaffError;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::ProviderRequestType;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use std::sync::Arc;
use tracing::{debug, info, info_span, warn, Instrument};
use crate::handlers::llm::router::router_chat_get_upstream_model;
use crate::router::llm::RouterService;
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
/// and parsed preferences. The `routing_policy` field is removed from the JSON
/// before re-serializing so downstream parsers don't see the non-standard field.
///
/// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB.
pub fn extract_routing_policy(
raw_bytes: &[u8],
warn_on_size: bool,
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
let preferences = json_body
.as_object_mut()
.and_then(|obj| obj.remove("routing_policy"))
.and_then(|policy_value| {
if warn_on_size {
let policy_str = serde_json::to_string(&policy_value).unwrap_or_default();
if policy_str.len() > ROUTING_POLICY_SIZE_WARNING_BYTES {
warn!(
size_bytes = policy_str.len(),
limit_bytes = ROUTING_POLICY_SIZE_WARNING_BYTES,
"routing_policy exceeds recommended size limit"
);
}
}
match serde_json::from_value::<Vec<ModelUsagePreference>>(policy_value) {
Ok(prefs) => {
info!(
num_models = prefs.len(),
"using inline routing_policy from request body"
);
Some(prefs)
}
Err(err) => {
warn!(error = %err, "failed to parse routing_policy");
None
}
}
});
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
Ok((bytes, preferences))
}
#[derive(serde::Serialize)]
struct RoutingDecisionResponse {
model: String,
route: Option<String>,
trace_id: String,
}
pub async fn routing_decision(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
request_path: String,
span_attributes: Arc<Option<SpanAttributes>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_headers = request.headers().clone();
let request_id: String = request_headers
.get(REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let custom_attrs =
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
let request_span = info_span!(
"routing_decision",
component = "routing",
request_id = %request_id,
http.method = %request.method(),
http.path = %request_path,
);
routing_decision_inner(
request,
router_service,
request_id,
request_path,
request_headers,
custom_attrs,
)
.instrument(request_span)
.await
}
async fn routing_decision_inner(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
request_id: String,
request_path: String,
request_headers: hyper::HeaderMap,
custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
set_service_name(operation_component::ROUTING);
opentelemetry::trace::get_active_span(|span| {
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
// Extract or generate traceparent
let traceparent: String = match request_headers
.get(TRACE_PARENT_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
{
Some(tp) => tp,
None => {
let trace_id = uuid::Uuid::new_v4().to_string().replace("-", "");
let generated_tp = format!("00-{}-0000000000000000-01", trace_id);
warn!(
generated_traceparent = %generated_tp,
"TRACE_PARENT header missing, generated new traceparent"
);
generated_tp
}
};
// Extract trace_id from traceparent (format: 00-{trace_id}-{span_id}-{flags})
let trace_id = traceparent
.split('-')
.nth(1)
.unwrap_or("unknown")
.to_string();
// Parse request body
let raw_bytes = request.collect().await?.to_bytes();
debug!(
body = %String::from_utf8_lossy(&raw_bytes),
"routing decision request body received"
);
// Extract routing_policy from request body before parsing as ProviderRequestType
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
Ok(result) => result,
Err(err) => {
warn!(error = %err, "failed to parse request JSON");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request JSON: {}",
err
))
.into_response());
}
};
let client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!(error = %err, "failed to parse request for routing decision");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request: {}",
err
))
.into_response());
}
};
// Call the existing routing logic with inline preferences
let routing_result = router_chat_get_upstream_model(
router_service,
client_request,
&traceparent,
&request_path,
&request_id,
inline_preferences,
)
.await;
match routing_result {
Ok(result) => {
let response = RoutingDecisionResponse {
model: result.model_name,
route: result.route_name,
trace_id,
};
info!(
model = %response.model,
route = ?response.route,
"routing decision completed"
);
let json = serde_json::to_string(&response).unwrap();
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap())
}
Err(err) => {
warn!(error = %err.message, "routing decision failed");
Ok(BrightStaffError::InternalServerError(err.message).into_response())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_chat_body(extra_fields: &str) -> Vec<u8> {
let extra = if extra_fields.is_empty() {
String::new()
} else {
format!(", {}", extra_fields)
};
format!(
r#"{{"model": "gpt-4o-mini", "messages": [{{"role": "user", "content": "hello"}}]{}}}"#,
extra
)
.into_bytes()
}
#[test]
fn extract_routing_policy_no_policy() {
let body = make_chat_body("");
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
assert!(prefs.is_none());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn extract_routing_policy_valid_policy() {
let policy = r#""routing_policy": [
{
"model": "openai/gpt-4o",
"routing_preferences": [
{"name": "coding", "description": "code generation tasks"}
]
},
{
"model": "openai/gpt-4o-mini",
"routing_preferences": [
{"name": "general", "description": "general questions"}
]
}
]"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("should have parsed preferences");
assert_eq!(prefs.len(), 2);
assert_eq!(prefs[0].model, "openai/gpt-4o");
assert_eq!(prefs[0].routing_preferences[0].name, "coding");
assert_eq!(prefs[1].model, "openai/gpt-4o-mini");
assert_eq!(prefs[1].routing_preferences[0].name, "general");
// routing_policy should be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
}
#[test]
fn extract_routing_policy_invalid_policy_returns_none() {
// routing_policy is present but has wrong shape
let policy = r#""routing_policy": "not-an-array""#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
// Invalid policy should be ignored (returns None), not error
assert!(prefs.is_none());
// routing_policy should still be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn extract_routing_policy_invalid_json_returns_error() {
let body = b"not valid json";
let result = extract_routing_policy(body, false);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Failed to parse JSON"));
}
#[test]
fn extract_routing_policy_empty_array() {
let policy = r#""routing_policy": []"#;
let body = make_chat_body(policy);
let (_, prefs) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("empty array is valid");
assert_eq!(prefs.len(), 0);
}
#[test]
fn extract_routing_policy_preserves_other_fields() {
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
assert!(prefs.is_some());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["temperature"], 0.5);
assert_eq!(cleaned_json["max_tokens"], 100);
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn routing_decision_response_serialization() {
let response = RoutingDecisionResponse {
model: "openai/gpt-4o".to_string(),
route: Some("code_generation".to_string()),
trace_id: "abc123".to_string(),
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "openai/gpt-4o");
assert_eq!(parsed["route"], "code_generation");
assert_eq!(parsed["trace_id"], "abc123");
}
#[test]
fn routing_decision_response_serialization_no_route() {
let response = RoutingDecisionResponse {
model: "none".to_string(),
route: None,
trace_id: "abc123".to_string(),
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "none");
assert!(parsed["route"].is_null());
}
}

View file

@ -3,6 +3,7 @@ use brightstaff::handlers::agents::orchestrator::agent_chat;
use brightstaff::handlers::function_calling::function_calling_chat_handler;
use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::router::llm::RouterService;
use brightstaff::router::orchestrator::OrchestratorService;
use brightstaff::state::memory::MemoryConversationalStorage;
@ -221,6 +222,24 @@ async fn route(
}
}
// --- Routing decision routes (/routing/...) ---
if let Some(stripped) = path.strip_prefix("/routing") {
let stripped = stripped.to_string();
if matches!(
stripped.as_str(),
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
return routing_decision(
req,
Arc::clone(&state.router_service),
stripped,
Arc::clone(&state.span_attributes),
)
.with_context(parent_cx)
.await;
}
}
// --- Standard routes ---
match (req.method(), path.as_str()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {

View file

@ -112,6 +112,7 @@ pub fn extract_input_items(input: &InputParam) -> Vec<InputItem> {
}]),
})]
}
InputParam::SingleItem(item) => vec![item.clone()],
InputParam::Items(items) => items.clone(),
}
}
@ -128,3 +129,101 @@ pub async fn retrieve_and_combine_input(
let combined_input = storage.merge(&prev_state, current_input);
Ok(combined_input)
}
#[cfg(test)]
mod tests {
use super::extract_input_items;
use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole,
};
#[test]
fn test_extract_input_items_converts_text_to_user_message_item() {
let extracted = extract_input_items(&InputParam::Text("hello world".to_string()));
assert_eq!(extracted.len(), 1);
let InputItem::Message(message) = &extracted[0] else {
panic!("expected InputItem::Message");
};
assert!(matches!(message.role, MessageRole::User));
let MessageContent::Items(items) = &message.content else {
panic!("expected MessageContent::Items");
};
assert_eq!(items.len(), 1);
let InputContent::InputText { text } = &items[0] else {
panic!("expected InputContent::InputText");
};
assert_eq!(text, "hello world");
}
#[test]
fn test_extract_input_items_preserves_single_item() {
let item = InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: MessageContent::Items(vec![InputContent::InputText {
text: "assistant note".to_string(),
}]),
});
let extracted = extract_input_items(&InputParam::SingleItem(item.clone()));
assert_eq!(extracted.len(), 1);
let InputItem::Message(message) = &extracted[0] else {
panic!("expected InputItem::Message");
};
assert!(matches!(message.role, MessageRole::Assistant));
let MessageContent::Items(items) = &message.content else {
panic!("expected MessageContent::Items");
};
let InputContent::InputText { text } = &items[0] else {
panic!("expected InputContent::InputText");
};
assert_eq!(text, "assistant note");
}
#[test]
fn test_extract_input_items_preserves_items_list() {
let items = vec![
InputItem::Message(InputMessage {
role: MessageRole::User,
content: MessageContent::Items(vec![InputContent::InputText {
text: "first".to_string(),
}]),
}),
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: MessageContent::Items(vec![InputContent::InputText {
text: "second".to_string(),
}]),
}),
];
let extracted = extract_input_items(&InputParam::Items(items.clone()));
assert_eq!(extracted.len(), items.len());
let InputItem::Message(first) = &extracted[0] else {
panic!("expected first item to be message");
};
assert!(matches!(first.role, MessageRole::User));
let MessageContent::Items(first_items) = &first.content else {
panic!("expected MessageContent::Items");
};
let InputContent::InputText { text: first_text } = &first_items[0] else {
panic!("expected InputContent::InputText");
};
assert_eq!(first_text, "first");
let InputItem::Message(second) = &extracted[1] else {
panic!("expected second item to be message");
};
assert!(matches!(second.role, MessageRole::Assistant));
let MessageContent::Items(second_items) = &second.content else {
panic!("expected MessageContent::Items");
};
let InputContent::InputText { text: second_text } = &second_items[0] else {
panic!("expected InputContent::InputText");
};
assert_eq!(second_text, "second");
}
}

View file

@ -108,7 +108,7 @@ pub struct ChatCompletionsRequest {
pub top_p: Option<f32>,
pub top_logprobs: Option<u32>,
pub user: Option<String>,
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
pub web_search_options: Option<Value>,
// VLLM-specific parameters (used by Arch-Function)
pub top_k: Option<u32>,

View file

@ -116,6 +116,8 @@ pub enum InputParam {
Text(String),
/// Array of input items (messages, references, outputs, etc.)
Items(Vec<InputItem>),
/// Single input item (some clients send object instead of array)
SingleItem(InputItem),
}
/// Input item - can be a message, item reference, function call output, etc.
@ -130,12 +132,20 @@ pub enum InputItem {
item_type: String,
id: String,
},
/// Function call emitted by model in prior turn
FunctionCall {
#[serde(rename = "type")]
item_type: String,
name: String,
arguments: String,
call_id: String,
},
/// Function call output
FunctionCallOutput {
#[serde(rename = "type")]
item_type: String,
call_id: String,
output: String,
output: serde_json::Value,
},
}
@ -166,6 +176,7 @@ pub enum MessageRole {
Assistant,
System,
Developer,
Tool,
}
/// Input content types
@ -173,6 +184,7 @@ pub enum MessageRole {
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InputContent {
/// Text input
#[serde(rename = "input_text", alias = "text", alias = "output_text")]
InputText { text: String },
/// Image input via URL
InputImage {
@ -180,6 +192,7 @@ pub enum InputContent {
detail: Option<String>,
},
/// File input via URL
#[serde(rename = "input_file", alias = "file")]
InputFile { file_url: String },
/// Audio input
InputAudio {
@ -207,10 +220,11 @@ pub struct AudioConfig {
}
/// Text configuration
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextConfig {
/// Text format configuration
pub format: TextFormat,
pub format: Option<TextFormat>,
}
/// Text format
@ -285,6 +299,7 @@ pub enum Tool {
filters: Option<serde_json::Value>,
},
/// Web search tool
#[serde(rename = "web_search", alias = "web_search_preview")]
WebSearchPreview {
domains: Option<Vec<String>>,
search_context_size: Option<String>,
@ -298,6 +313,12 @@ pub enum Tool {
display_height_px: Option<i32>,
display_number: Option<i32>,
},
/// Custom tool (provider/SDK-specific tool contract)
Custom {
name: Option<String>,
description: Option<String>,
format: Option<serde_json::Value>,
},
}
/// Ranking options for file search
@ -1015,6 +1036,30 @@ pub struct ListInputItemsResponse {
// ProviderRequest Implementation
// ============================================================================
fn append_input_content_text(buffer: &mut String, content: &InputContent) {
match content {
InputContent::InputText { text } => buffer.push_str(text),
InputContent::InputImage { .. } => buffer.push_str("[Image]"),
InputContent::InputFile { .. } => buffer.push_str("[File]"),
InputContent::InputAudio { .. } => buffer.push_str("[Audio]"),
}
}
fn append_content_items_text(buffer: &mut String, content_items: &[InputContent]) {
for content in content_items {
// Preserve existing behavior: each content item is prefixed with a space.
buffer.push(' ');
append_input_content_text(buffer, content);
}
}
fn append_message_content_text(buffer: &mut String, content: &MessageContent) {
match content {
MessageContent::Text(text) => buffer.push_str(text),
MessageContent::Items(content_items) => append_content_items_text(buffer, content_items),
}
}
impl ProviderRequest for ResponsesAPIRequest {
fn model(&self) -> &str {
&self.model
@ -1031,36 +1076,27 @@ impl ProviderRequest for ResponsesAPIRequest {
fn extract_messages_text(&self) -> String {
match &self.input {
InputParam::Text(text) => text.clone(),
InputParam::Items(items) => {
items.iter().fold(String::new(), |acc, item| {
match item {
InputItem::Message(msg) => {
let content_text = match &msg.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Items(content_items) => {
content_items.iter().fold(String::new(), |acc, content| {
acc + " "
+ &match content {
InputContent::InputText { text } => text.clone(),
InputContent::InputImage { .. } => {
"[Image]".to_string()
}
InputContent::InputFile { .. } => {
"[File]".to_string()
}
InputContent::InputAudio { .. } => {
"[Audio]".to_string()
}
}
})
}
};
acc + " " + &content_text
}
// Skip non-message items (references, outputs, etc.)
_ => acc,
InputParam::SingleItem(item) => {
// Normalize single-item input for extraction behavior parity.
match item {
InputItem::Message(msg) => {
let mut extracted = String::new();
append_message_content_text(&mut extracted, &msg.content);
extracted
}
})
_ => String::new(),
}
}
InputParam::Items(items) => {
let mut extracted = String::new();
for item in items {
if let InputItem::Message(msg) = item {
// Preserve existing behavior: each message is prefixed with a space.
extracted.push(' ');
append_message_content_text(&mut extracted, &msg.content);
}
}
extracted
}
}
}
@ -1068,6 +1104,20 @@ impl ProviderRequest for ResponsesAPIRequest {
fn get_recent_user_message(&self) -> Option<String> {
match &self.input {
InputParam::Text(text) => Some(text.clone()),
InputParam::SingleItem(item) => match item {
InputItem::Message(msg) if matches!(msg.role, MessageRole::User) => {
match &msg.content {
MessageContent::Text(text) => Some(text.clone()),
MessageContent::Items(content_items) => {
content_items.iter().find_map(|content| match content {
InputContent::InputText { text } => Some(text.clone()),
_ => None,
})
}
}
}
_ => None,
},
InputParam::Items(items) => {
items.iter().rev().find_map(|item| {
match item {
@ -1097,6 +1147,9 @@ impl ProviderRequest for ResponsesAPIRequest {
.iter()
.filter_map(|tool| match tool {
Tool::Function { name, .. } => Some(name.clone()),
Tool::Custom {
name: Some(name), ..
} => Some(name.clone()),
// Other tool types don't have user-defined names
_ => None,
})
@ -1366,6 +1419,7 @@ impl crate::providers::streaming_response::ProviderStreamResponse for ResponsesA
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_response_output_text_delta_deserialization() {
@ -1506,4 +1560,87 @@ mod tests {
_ => panic!("Expected ResponseCompleted event"),
}
}
#[test]
fn test_request_deserializes_custom_tool() {
let request = json!({
"model": "gpt-5.3-codex",
"input": "apply the patch",
"tools": [
{
"type": "custom",
"name": "run_patch",
"description": "Apply patch text",
"format": {
"kind": "patch",
"version": "v1"
}
}
]
});
let bytes = serde_json::to_vec(&request).unwrap();
let parsed = ResponsesAPIRequest::try_from(bytes.as_slice()).unwrap();
let tools = parsed.tools.expect("tools should be present");
assert_eq!(tools.len(), 1);
match &tools[0] {
Tool::Custom {
name,
description,
format,
} => {
assert_eq!(name.as_deref(), Some("run_patch"));
assert_eq!(description.as_deref(), Some("Apply patch text"));
assert!(format.is_some());
}
_ => panic!("expected custom tool"),
}
}
#[test]
fn test_request_deserializes_web_search_tool_alias() {
let request = json!({
"model": "gpt-5.3-codex",
"input": "find repository info",
"tools": [
{
"type": "web_search",
"domains": ["github.com"],
"search_context_size": "medium"
}
]
});
let bytes = serde_json::to_vec(&request).unwrap();
let parsed = ResponsesAPIRequest::try_from(bytes.as_slice()).unwrap();
let tools = parsed.tools.expect("tools should be present");
assert_eq!(tools.len(), 1);
match &tools[0] {
Tool::WebSearchPreview {
domains,
search_context_size,
..
} => {
assert_eq!(domains.as_ref().map(Vec::len), Some(1));
assert_eq!(search_context_size.as_deref(), Some("medium"));
}
_ => panic!("expected web search preview tool"),
}
}
#[test]
fn test_request_deserializes_text_config_without_format() {
let request = json!({
"model": "gpt-5.3-codex",
"input": "hello",
"text": {}
});
let bytes = serde_json::to_vec(&request).unwrap();
let parsed = ResponsesAPIRequest::try_from(bytes.as_slice()).unwrap();
assert!(parsed.text.is_some());
assert!(parsed.text.unwrap().format.is_none());
}
}

View file

@ -74,6 +74,7 @@ pub struct ResponsesAPIStreamBuffer {
/// Lifecycle state flags
created_emitted: bool,
in_progress_emitted: bool,
finalized: bool,
/// Track which output items we've added
output_items_added: HashMap<i32, String>, // output_index -> item_id
@ -109,6 +110,7 @@ impl ResponsesAPIStreamBuffer {
upstream_response_metadata: None,
created_emitted: false,
in_progress_emitted: false,
finalized: false,
output_items_added: HashMap::new(),
text_content: HashMap::new(),
function_arguments: HashMap::new(),
@ -236,7 +238,7 @@ impl ResponsesAPIStreamBuffer {
}),
store: Some(true),
text: Some(TextConfig {
format: TextFormat::Text,
format: Some(TextFormat::Text),
}),
audio: None,
modalities: None,
@ -255,8 +257,38 @@ impl ResponsesAPIStreamBuffer {
/// Finalize the response by emitting all *.done events and response.completed.
/// Call this when the stream is complete (after seeing [DONE] or end_of_stream).
pub fn finalize(&mut self) {
// Idempotent finalize: avoid duplicate response.completed loops.
if self.finalized {
return;
}
self.finalized = true;
let mut events = Vec::new();
// Ensure lifecycle prelude is emitted even if finalize is triggered
// by finish_reason before any prior delta was processed.
if !self.created_emitted {
if self.response_id.is_none() {
self.response_id = Some(format!(
"resp_{}",
uuid::Uuid::new_v4().to_string().replace("-", "")
));
self.created_at = Some(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64,
);
self.model = Some("unknown".to_string());
}
events.push(self.create_response_created_event());
self.created_emitted = true;
}
if !self.in_progress_emitted {
events.push(self.create_response_in_progress_event());
self.in_progress_emitted = true;
}
// Emit done events for all accumulated content
// Text content done events
@ -443,6 +475,12 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
}
};
// Explicit completion marker from transform layer.
if matches!(stream_event.as_ref(), ResponsesAPIStreamEvent::Done { .. }) {
self.finalize();
return;
}
let mut events = Vec::new();
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present
@ -789,4 +827,30 @@ mod tests {
println!("✓ NO completion events (partial stream, no [DONE])");
println!("✓ Arguments accumulated: '{{\"location\":\"'\n");
}
#[test]
fn test_finish_reason_without_done_still_finalizes_once() {
let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}"#;
let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
let mut buffer = ResponsesAPIStreamBuffer::new();
for raw_event in stream_iter {
let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
let output = String::from_utf8_lossy(&buffer.to_bytes()).to_string();
let completed_count = output.matches("event: response.completed").count();
assert_eq!(
completed_count, 1,
"response.completed should be emitted exactly once"
);
}
}

View file

@ -184,8 +184,8 @@ impl SupportedAPIsFromClient {
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
// For Responses API, check if provider supports it, otherwise translate to chat/completions
match provider_id {
// OpenAI and compatible providers that support /v1/responses
ProviderId::OpenAI => route_by_provider("/responses"),
// Providers that support /v1/responses natively
ProviderId::OpenAI | ProviderId::XAI => route_by_provider("/responses"),
// All other providers: translate to /chat/completions
_ => route_by_provider("/chat/completions"),
}
@ -654,4 +654,19 @@ mod tests {
"/custom/azure/path/gpt-4-deployment/chat/completions?api-version=2025-01-01-preview"
);
}
#[test]
fn test_responses_api_targets_xai_native_responses_endpoint() {
let api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
assert_eq!(
api.target_endpoint_for_provider(
&ProviderId::XAI,
"/v1/responses",
"grok-4-1-fast-reasoning",
false,
None
),
"/v1/responses"
);
}
}

View file

@ -166,10 +166,11 @@ impl ProviderId {
SupportedAPIsFromClient::OpenAIChatCompletions(_),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// OpenAI Responses API - only OpenAI supports this
(ProviderId::OpenAI, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses)
}
// OpenAI Responses API - OpenAI and xAI support this natively
(
ProviderId::OpenAI | ProviderId::XAI,
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses),
// Amazon Bedrock natively supports Bedrock APIs
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
@ -328,4 +329,16 @@ mod tests {
"AmazonBedrock should have models (mapped to amazon)"
);
}
#[test]
fn test_xai_uses_responses_api_for_responses_clients() {
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
let upstream = ProviderId::XAI.compatible_api_for_client(&client_api, false);
assert!(matches!(
upstream,
SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses)
));
}
}

View file

@ -5,6 +5,7 @@ use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
use crate::apis::openai_responses::ResponsesAPIRequest;
use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use crate::ProviderId;
use serde_json::Value;
use std::collections::HashMap;
@ -70,6 +71,25 @@ impl ProviderRequestType {
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
}
}
/// Apply provider-specific request normalization before sending upstream.
pub fn normalize_for_upstream(
&mut self,
provider_id: ProviderId,
upstream_api: &SupportedUpstreamAPIs,
) {
if provider_id == ProviderId::XAI
&& matches!(
upstream_api,
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
)
{
if let Self::ChatCompletionsRequest(req) = self {
// xAI's legacy live-search shape is deprecated on chat/completions.
req.web_search_options = None;
}
}
}
}
impl ProviderRequest for ProviderRequestType {
@ -787,6 +807,62 @@ mod tests {
}
}
#[test]
fn test_normalize_for_upstream_xai_clears_chat_web_search_options() {
use crate::apis::openai::{Message, MessageContent, OpenAIApi, Role};
let mut request = ProviderRequestType::ChatCompletionsRequest(ChatCompletionsRequest {
model: "grok-4".to_string(),
messages: vec![Message {
role: Role::User,
content: Some(MessageContent::Text("hello".to_string())),
name: None,
tool_calls: None,
tool_call_id: None,
}],
web_search_options: Some(serde_json::json!({"search_context_size":"medium"})),
..Default::default()
});
request.normalize_for_upstream(
ProviderId::XAI,
&SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
);
let ProviderRequestType::ChatCompletionsRequest(req) = request else {
panic!("expected chat request");
};
assert!(req.web_search_options.is_none());
}
#[test]
fn test_normalize_for_upstream_non_xai_keeps_chat_web_search_options() {
use crate::apis::openai::{Message, MessageContent, OpenAIApi, Role};
let mut request = ProviderRequestType::ChatCompletionsRequest(ChatCompletionsRequest {
model: "gpt-4o".to_string(),
messages: vec![Message {
role: Role::User,
content: Some(MessageContent::Text("hello".to_string())),
name: None,
tool_calls: None,
tool_call_id: None,
}],
web_search_options: Some(serde_json::json!({"search_context_size":"medium"})),
..Default::default()
});
request.normalize_for_upstream(
ProviderId::OpenAI,
&SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
);
let ProviderRequestType::ChatCompletionsRequest(req) = request else {
panic!("expected chat request");
};
assert!(req.web_search_options.is_some());
}
#[test]
fn test_responses_api_to_anthropic_messages_conversion() {
use crate::apis::anthropic::AnthropicApi::Messages;

View file

@ -10,7 +10,8 @@ use crate::apis::anthropic::{
ToolResultContent,
};
use crate::apis::openai::{
ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
ChatCompletionsRequest, FunctionCall as OpenAIFunctionCall, Message, MessageContent, Role,
Tool, ToolCall as OpenAIToolCall, ToolChoice, ToolChoiceType,
};
use crate::apis::openai_responses::{
@ -65,6 +66,14 @@ impl TryFrom<ResponsesInputConverter> for Vec<Message> {
Ok(messages)
}
InputParam::SingleItem(item) => {
// Some clients send a single object instead of an array.
let nested = ResponsesInputConverter {
input: InputParam::Items(vec![item]),
instructions: converter.instructions,
};
Vec::<Message>::try_from(nested)
}
InputParam::Items(items) => {
// Convert input items to messages
let mut converted_messages = Vec::new();
@ -82,82 +91,145 @@ impl TryFrom<ResponsesInputConverter> for Vec<Message> {
// Convert each input item
for item in items {
if let InputItem::Message(input_msg) = item {
let role = match input_msg.role {
MessageRole::User => Role::User,
MessageRole::Assistant => Role::Assistant,
MessageRole::System => Role::System,
MessageRole::Developer => Role::System, // Map developer to system
};
match item {
InputItem::Message(input_msg) => {
let role = match input_msg.role {
MessageRole::User => Role::User,
MessageRole::Assistant => Role::Assistant,
MessageRole::System => Role::System,
MessageRole::Developer => Role::System, // Map developer to system
MessageRole::Tool => Role::Tool,
};
// Convert content based on MessageContent type
let content = match &input_msg.content {
crate::apis::openai_responses::MessageContent::Text(text) => {
// Simple text content
MessageContent::Text(text.clone())
}
crate::apis::openai_responses::MessageContent::Items(content_items) => {
// Check if it's a single text item (can use simple text format)
if content_items.len() == 1 {
if let InputContent::InputText { text } = &content_items[0] {
MessageContent::Text(text.clone())
// Convert content based on MessageContent type
let content = match &input_msg.content {
crate::apis::openai_responses::MessageContent::Text(text) => {
// Simple text content
MessageContent::Text(text.clone())
}
crate::apis::openai_responses::MessageContent::Items(
content_items,
) => {
// Check if it's a single text item (can use simple text format)
if content_items.len() == 1 {
if let InputContent::InputText { text } = &content_items[0]
{
MessageContent::Text(text.clone())
} else {
// Single non-text item - use parts format
MessageContent::Parts(
content_items
.iter()
.filter_map(|c| match c {
InputContent::InputText { text } => {
Some(crate::apis::openai::ContentPart::Text {
text: text.clone(),
})
}
InputContent::InputImage { image_url, .. } => {
Some(crate::apis::openai::ContentPart::ImageUrl {
image_url: crate::apis::openai::ImageUrl {
url: image_url.clone(),
detail: None,
},
})
}
InputContent::InputFile { .. } => None, // Skip files for now
InputContent::InputAudio { .. } => None, // Skip audio for now
})
.collect(),
)
}
} else {
// Single non-text item - use parts format
// Multiple content items - convert to parts
MessageContent::Parts(
content_items.iter()
content_items
.iter()
.filter_map(|c| match c {
InputContent::InputText { text } => {
Some(crate::apis::openai::ContentPart::Text { text: text.clone() })
Some(crate::apis::openai::ContentPart::Text {
text: text.clone(),
})
}
InputContent::InputImage { image_url, .. } => {
Some(crate::apis::openai::ContentPart::ImageUrl {
image_url: crate::apis::openai::ImageUrl {
url: image_url.clone(),
detail: None,
}
},
})
}
InputContent::InputFile { .. } => None, // Skip files for now
InputContent::InputAudio { .. } => None, // Skip audio for now
})
.collect()
.collect(),
)
}
} else {
// Multiple content items - convert to parts
MessageContent::Parts(
content_items
.iter()
.filter_map(|c| match c {
InputContent::InputText { text } => {
Some(crate::apis::openai::ContentPart::Text {
text: text.clone(),
})
}
InputContent::InputImage { image_url, .. } => Some(
crate::apis::openai::ContentPart::ImageUrl {
image_url: crate::apis::openai::ImageUrl {
url: image_url.clone(),
detail: None,
},
},
),
InputContent::InputFile { .. } => None, // Skip files for now
InputContent::InputAudio { .. } => None, // Skip audio for now
})
.collect(),
)
}
};
converted_messages.push(Message {
role,
content: Some(content),
name: None,
tool_call_id: None,
tool_calls: None,
});
}
InputItem::FunctionCallOutput {
item_type: _,
call_id,
output,
} => {
// Preserve tool result so upstream models do not re-issue the same tool call.
let output_text = match output {
serde_json::Value::String(s) => s.clone(),
other => serde_json::to_string(&other).unwrap_or_default(),
};
converted_messages.push(Message {
role: Role::Tool,
content: Some(MessageContent::Text(output_text)),
name: None,
tool_call_id: Some(call_id),
tool_calls: None,
});
}
InputItem::FunctionCall {
item_type: _,
name,
arguments,
call_id,
} => {
let tool_call = OpenAIToolCall {
id: call_id,
call_type: "function".to_string(),
function: OpenAIFunctionCall { name, arguments },
};
// Prefer attaching tool_calls to the preceding assistant message when present.
if let Some(last) = converted_messages.last_mut() {
if matches!(last.role, Role::Assistant) {
if let Some(existing) = &mut last.tool_calls {
existing.push(tool_call);
} else {
last.tool_calls = Some(vec![tool_call]);
}
continue;
}
}
};
converted_messages.push(Message {
role,
content: Some(content),
name: None,
tool_call_id: None,
tool_calls: None,
});
converted_messages.push(Message {
role: Role::Assistant,
content: None,
name: None,
tool_call_id: None,
tool_calls: Some(vec![tool_call]),
});
}
InputItem::ItemReference { .. } => {
// Item references/unknown entries are metadata-like and can be skipped
// for chat-completions conversion.
}
}
}
@ -397,6 +469,170 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
type Error = TransformError;
fn try_from(req: ResponsesAPIRequest) -> Result<Self, Self::Error> {
fn normalize_function_parameters(
parameters: Option<serde_json::Value>,
fallback_extra: Option<serde_json::Value>,
) -> serde_json::Value {
// ChatCompletions function tools require JSON Schema with top-level type=object.
let mut base = serde_json::json!({
"type": "object",
"properties": {},
});
if let Some(serde_json::Value::Object(mut obj)) = parameters {
// Enforce a valid object schema shape regardless of upstream tool format.
obj.insert(
"type".to_string(),
serde_json::Value::String("object".to_string()),
);
if !obj.contains_key("properties") {
obj.insert(
"properties".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
}
base = serde_json::Value::Object(obj);
}
if let Some(extra) = fallback_extra {
if let serde_json::Value::Object(ref mut map) = base {
map.insert("x-custom-format".to_string(), extra);
}
}
base
}
let mut converted_chat_tools: Vec<Tool> = Vec::new();
let mut web_search_options: Option<serde_json::Value> = None;
if let Some(tools) = req.tools.clone() {
for (idx, tool) in tools.into_iter().enumerate() {
match tool {
ResponsesTool::Function {
name,
description,
parameters,
strict,
} => converted_chat_tools.push(Tool {
tool_type: "function".to_string(),
function: crate::apis::openai::Function {
name,
description,
parameters: normalize_function_parameters(parameters, None),
strict,
},
}),
ResponsesTool::WebSearchPreview {
search_context_size,
user_location,
..
} => {
if web_search_options.is_none() {
let user_location_value = user_location.map(|loc| {
let mut approx = serde_json::Map::new();
if let Some(city) = loc.city {
approx.insert(
"city".to_string(),
serde_json::Value::String(city),
);
}
if let Some(country) = loc.country {
approx.insert(
"country".to_string(),
serde_json::Value::String(country),
);
}
if let Some(region) = loc.region {
approx.insert(
"region".to_string(),
serde_json::Value::String(region),
);
}
if let Some(timezone) = loc.timezone {
approx.insert(
"timezone".to_string(),
serde_json::Value::String(timezone),
);
}
serde_json::json!({
"type": loc.location_type,
"approximate": serde_json::Value::Object(approx),
})
});
let mut web_search = serde_json::Map::new();
if let Some(size) = search_context_size {
web_search.insert(
"search_context_size".to_string(),
serde_json::Value::String(size),
);
}
if let Some(location) = user_location_value {
web_search.insert("user_location".to_string(), location);
}
web_search_options = Some(serde_json::Value::Object(web_search));
}
}
ResponsesTool::Custom {
name,
description,
format,
} => {
// Custom tools do not have a strict ChatCompletions equivalent for all
// providers. Map them to a permissive function tool for compatibility.
let tool_name = name.unwrap_or_else(|| format!("custom_tool_{}", idx + 1));
let parameters = normalize_function_parameters(
Some(serde_json::json!({
"type": "object",
"properties": {
"input": { "type": "string" }
},
"required": ["input"],
"additionalProperties": true,
})),
format,
);
converted_chat_tools.push(Tool {
tool_type: "function".to_string(),
function: crate::apis::openai::Function {
name: tool_name,
description,
parameters,
strict: Some(false),
},
});
}
ResponsesTool::FileSearch { .. } => {
return Err(TransformError::UnsupportedConversion(
"FileSearch tool is not supported in ChatCompletions API. Only function/custom/web search tools are supported in this conversion."
.to_string(),
));
}
ResponsesTool::CodeInterpreter => {
return Err(TransformError::UnsupportedConversion(
"CodeInterpreter tool is not supported in ChatCompletions API conversion."
.to_string(),
));
}
ResponsesTool::Computer { .. } => {
return Err(TransformError::UnsupportedConversion(
"Computer tool is not supported in ChatCompletions API conversion."
.to_string(),
));
}
}
}
}
let tools = if converted_chat_tools.is_empty() {
None
} else {
Some(converted_chat_tools)
};
// Convert input to messages using the shared converter
let converter = ResponsesInputConverter {
input: req.input,
@ -418,57 +654,24 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
service_tier: req.service_tier,
top_logprobs: req.top_logprobs.map(|t| t as u32),
modalities: req.modalities.map(|mods| {
mods.into_iter().map(|m| {
match m {
mods.into_iter()
.map(|m| match m {
Modality::Text => "text".to_string(),
Modality::Audio => "audio".to_string(),
}
}).collect()
})
.collect()
}),
stream_options: req.stream_options.map(|opts| {
crate::apis::openai::StreamOptions {
stream_options: req
.stream_options
.map(|opts| crate::apis::openai::StreamOptions {
include_usage: opts.include_usage,
}
}),
reasoning_effort: req.reasoning_effort.map(|effort| match effort {
ReasoningEffort::Low => "low".to_string(),
ReasoningEffort::Medium => "medium".to_string(),
ReasoningEffort::High => "high".to_string(),
}),
reasoning_effort: req.reasoning_effort.map(|effort| {
match effort {
ReasoningEffort::Low => "low".to_string(),
ReasoningEffort::Medium => "medium".to_string(),
ReasoningEffort::High => "high".to_string(),
}
}),
tools: req.tools.map(|tools| {
tools.into_iter().map(|tool| {
// Only convert Function tools - other types are not supported in ChatCompletions
match tool {
ResponsesTool::Function { name, description, parameters, strict } => Ok(Tool {
tool_type: "function".to_string(),
function: crate::apis::openai::Function {
name,
description,
parameters: parameters.unwrap_or_else(|| serde_json::json!({
"type": "object",
"properties": {}
})),
strict,
}
}),
ResponsesTool::FileSearch { .. } => Err(TransformError::UnsupportedConversion(
"FileSearch tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
)),
ResponsesTool::WebSearchPreview { .. } => Err(TransformError::UnsupportedConversion(
"WebSearchPreview tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
)),
ResponsesTool::CodeInterpreter => Err(TransformError::UnsupportedConversion(
"CodeInterpreter tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
)),
ResponsesTool::Computer { .. } => Err(TransformError::UnsupportedConversion(
"Computer tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
)),
}
}).collect::<Result<Vec<_>, _>>()
}).transpose()?,
tools,
tool_choice: req.tool_choice.map(|choice| {
match choice {
ResponsesToolChoice::String(s) => {
@ -481,11 +684,14 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
}
ResponsesToolChoice::Named { function, .. } => ToolChoice::Function {
choice_type: "function".to_string(),
function: crate::apis::openai::FunctionChoice { name: function.name }
}
function: crate::apis::openai::FunctionChoice {
name: function.name,
},
},
}
}),
parallel_tool_calls: req.parallel_tool_calls,
web_search_options,
..Default::default()
})
}
@ -1027,4 +1233,235 @@ mod tests {
panic!("Expected text content block");
}
}
#[test]
fn test_responses_custom_tool_maps_to_function_tool_for_chat_completions() {
use crate::apis::openai_responses::{
InputParam, ResponsesAPIRequest, Tool as ResponsesTool,
};
let req = ResponsesAPIRequest {
model: "gpt-5.3-codex".to_string(),
input: InputParam::Text("use custom tool".to_string()),
tools: Some(vec![ResponsesTool::Custom {
name: Some("run_patch".to_string()),
description: Some("Apply structured patch".to_string()),
format: Some(serde_json::json!({
"kind": "patch",
"version": "v1"
})),
}]),
include: None,
parallel_tool_calls: None,
store: None,
instructions: None,
stream: None,
stream_options: None,
conversation: None,
tool_choice: None,
max_output_tokens: None,
temperature: None,
top_p: None,
metadata: None,
previous_response_id: None,
modalities: None,
audio: None,
text: None,
reasoning_effort: None,
truncation: None,
user: None,
max_tool_calls: None,
service_tier: None,
background: None,
top_logprobs: None,
};
let converted = ChatCompletionsRequest::try_from(req).expect("conversion should succeed");
let tools = converted.tools.expect("tools should be present");
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].tool_type, "function");
assert_eq!(tools[0].function.name, "run_patch");
assert_eq!(
tools[0].function.description.as_deref(),
Some("Apply structured patch")
);
}
#[test]
fn test_responses_web_search_maps_to_chat_web_search_options() {
use crate::apis::openai_responses::{
InputParam, ResponsesAPIRequest, Tool as ResponsesTool, UserLocation,
};
let req = ResponsesAPIRequest {
model: "gpt-5.3-codex".to_string(),
input: InputParam::Text("find project docs".to_string()),
tools: Some(vec![ResponsesTool::WebSearchPreview {
domains: Some(vec!["docs.planoai.dev".to_string()]),
search_context_size: Some("medium".to_string()),
user_location: Some(UserLocation {
location_type: "approximate".to_string(),
city: Some("San Francisco".to_string()),
country: Some("US".to_string()),
region: Some("CA".to_string()),
timezone: Some("America/Los_Angeles".to_string()),
}),
}]),
include: None,
parallel_tool_calls: None,
store: None,
instructions: None,
stream: None,
stream_options: None,
conversation: None,
tool_choice: None,
max_output_tokens: None,
temperature: None,
top_p: None,
metadata: None,
previous_response_id: None,
modalities: None,
audio: None,
text: None,
reasoning_effort: None,
truncation: None,
user: None,
max_tool_calls: None,
service_tier: None,
background: None,
top_logprobs: None,
};
let converted = ChatCompletionsRequest::try_from(req).expect("conversion should succeed");
assert!(converted.web_search_options.is_some());
}
#[test]
fn test_responses_function_call_output_maps_to_tool_message() {
use crate::apis::openai_responses::{
InputItem, InputParam, ResponsesAPIRequest, Tool as ResponsesTool,
};
let req = ResponsesAPIRequest {
model: "gpt-5.3-codex".to_string(),
input: InputParam::Items(vec![InputItem::FunctionCallOutput {
item_type: "function_call_output".to_string(),
call_id: "call_123".to_string(),
output: serde_json::json!({"status":"ok","stdout":"hello"}),
}]),
tools: Some(vec![ResponsesTool::Function {
name: "exec_command".to_string(),
description: Some("Execute a shell command".to_string()),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"cmd": { "type": "string" }
},
"required": ["cmd"]
})),
strict: Some(false),
}]),
include: None,
parallel_tool_calls: None,
store: None,
instructions: None,
stream: None,
stream_options: None,
conversation: None,
tool_choice: None,
max_output_tokens: None,
temperature: None,
top_p: None,
metadata: None,
previous_response_id: None,
modalities: None,
audio: None,
text: None,
reasoning_effort: None,
truncation: None,
user: None,
max_tool_calls: None,
service_tier: None,
background: None,
top_logprobs: None,
};
let converted = ChatCompletionsRequest::try_from(req).expect("conversion should succeed");
assert_eq!(converted.messages.len(), 1);
assert!(matches!(converted.messages[0].role, Role::Tool));
assert_eq!(
converted.messages[0].tool_call_id.as_deref(),
Some("call_123")
);
}
#[test]
fn test_responses_function_call_and_output_preserve_call_id_link() {
use crate::apis::openai_responses::{
InputItem, InputMessage, MessageContent as ResponsesMessageContent, MessageRole,
ResponsesAPIRequest,
};
let req = ResponsesAPIRequest {
model: "gpt-5.3-codex".to_string(),
input: InputParam::Items(vec![
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: ResponsesMessageContent::Items(vec![]),
}),
InputItem::FunctionCall {
item_type: "function_call".to_string(),
name: "exec_command".to_string(),
arguments: "{\"cmd\":\"pwd\"}".to_string(),
call_id: "toolu_abc123".to_string(),
},
InputItem::FunctionCallOutput {
item_type: "function_call_output".to_string(),
call_id: "toolu_abc123".to_string(),
output: serde_json::Value::String("ok".to_string()),
},
]),
tools: None,
include: None,
parallel_tool_calls: None,
store: None,
instructions: None,
stream: None,
stream_options: None,
conversation: None,
tool_choice: None,
max_output_tokens: None,
temperature: None,
top_p: None,
metadata: None,
previous_response_id: None,
modalities: None,
audio: None,
text: None,
reasoning_effort: None,
truncation: None,
user: None,
max_tool_calls: None,
service_tier: None,
background: None,
top_logprobs: None,
};
let converted = ChatCompletionsRequest::try_from(req).expect("conversion should succeed");
assert_eq!(converted.messages.len(), 2);
assert!(matches!(converted.messages[0].role, Role::Assistant));
let tool_calls = converted.messages[0]
.tool_calls
.as_ref()
.expect("assistant tool_calls should be present");
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "toolu_abc123");
assert!(matches!(converted.messages[1].role, Role::Tool));
assert_eq!(
converted.messages[1].tool_call_id.as_deref(),
Some("toolu_abc123")
);
}
}

View file

@ -512,19 +512,12 @@ impl TryFrom<ChatCompletionsStreamResponse> for ResponsesAPIStreamEvent {
}
}
// Handle finish_reason - this is a completion signal
// Return an empty delta that the buffer can use to detect completion
// Handle finish_reason - this is a completion signal.
// Emit an explicit Done marker so the buffering layer can finalize
// even if an upstream [DONE] marker is missing/delayed.
if choice.finish_reason.is_some() {
// Return a minimal text delta to signal completion
// The buffer will handle the finish_reason and generate response.completed
return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta {
item_id: "".to_string(), // Buffer will fill this
output_index: choice.index as i32,
content_index: 0,
delta: "".to_string(), // Empty delta signals completion
logprobs: vec![],
obfuscation: None,
sequence_number: 0, // Buffer will fill this
return Ok(ResponsesAPIStreamEvent::Done {
sequence_number: 0, // Buffer will assign final sequence
});
}

View file

@ -1046,7 +1046,8 @@ impl HttpContext for StreamContext {
);
match ProviderRequestType::try_from((deserialized_client_request, upstream)) {
Ok(request) => {
Ok(mut request) => {
request.normalize_for_upstream(self.get_provider_id(), upstream);
debug!(
"request_id={}: upstream request payload: {}",
self.request_identifier(),