enable state management for v1/responses (#631)

* first commit with tests to enable state management via in-memory storage

* fixed logs to better follow the conversational flow

* added support for Supabase

* added the state_storage_v1_responses flag and use it to store state appropriately (see the storage-trait sketch after the commit metadata below)

* cleaned up logs and fixed an LLM gateway connectivity issue in the weather-forecast demo

* fixed mixed inputs from the OpenAI v1/responses API (#632)

* fixed mixed inputs from the OpenAI v1/responses API

* removed tracing from model-alias routing

* handled additional input types from the OpenAI Responses API

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>

* resolved PR comments

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
Salman Paracha · 2025-12-17 12:18:38 -08:00 · committed by GitHub
parent 33e90dd338
commit d5a273f740
26 changed files with 2687 additions and 76 deletions
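
Background for the diff below: the handler now threads an optional state store (Option<Arc<dyn StateStorage>>) through llm_chat and consults it only when the client calls the v1/responses API while the upstream provider does not support it natively. The diff confirms the StateStorage trait object, a StateStorageError::NotFound variant (mapped to 409 Conflict), and in-memory plus Supabase backends, but it does not show the trait's method signatures. A minimal sketch of what such a trait could look like, with assumed method names (store_response, get_response) and an assumed item type (serde_json::Value in place of the hermesllm input-item type):

    // Sketch only: StateStorage and StateStorageError::NotFound are real names
    // from this commit; the method signatures and item type are assumptions.
    use async_trait::async_trait;

    #[derive(Debug)]
    pub enum StateStorageError {
        NotFound(String), // unknown previous_response_id -> handler returns 409
        Backend(String),  // any other storage failure -> handler logs and continues
    }

    #[async_trait]
    pub trait StateStorage: Send + Sync {
        // Persist a completed exchange under its response id (assumed name).
        async fn store_response(
            &self,
            response_id: &str,
            items: Vec<serde_json::Value>,
        ) -> Result<(), StateStorageError>;

        // Load the stored conversation for a previous_response_id (assumed name).
        async fn get_response(
            &self,
            response_id: &str,
        ) -> Result<Vec<serde_json::Value>, StateStorageError>;
    }

Per the commit bullets, an in-memory backend (exercised by tests) and a Supabase backend implement this abstraction.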


@@ -1,8 +1,9 @@
 use bytes::Bytes;
 use common::configuration::{LlmProvider, ModelAlias};
-use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
+use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
 use common::traces::TraceCollector;
-use hermesllm::clients::SupportedAPIsFromClient;
+use hermesllm::apis::openai_responses::InputParam;
+use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
 use hermesllm::{ProviderRequest, ProviderRequestType};
 use http_body_util::combinators::BoxBody;
 use http_body_util::{BodyExt, Full};
@@ -11,11 +12,16 @@ use hyper::{Request, Response, StatusCode};
 use std::collections::HashMap;
 use std::sync::Arc;
 use tokio::sync::RwLock;
-use tracing::{debug, warn};
+use tracing::{debug, info, warn};
 use crate::router::llm_router::RouterService;
 use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
 use crate::handlers::router_chat::router_chat_get_upstream_model;
+use crate::state::response_state_processor::ResponsesStateProcessor;
+use crate::state::{
+    StateStorage, StateStorageError,
+    extract_input_items, retrieve_and_combine_input
+};
 use crate::tracing::operation_component;
 
 fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
@@ -31,14 +37,20 @@ pub async fn llm_chat(
     model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
     llm_providers: Arc<RwLock<Vec<LlmProvider>>>,
     trace_collector: Arc<TraceCollector>,
+    state_storage: Option<Arc<dyn StateStorage>>,
 ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
     let request_path = request.uri().path().to_string();
     let request_headers = request.headers().clone();
+    let request_id = request_headers
+        .get(REQUEST_ID_HEADER)
+        .and_then(|h| h.to_str().ok())
+        .map(|s| s.to_string())
+        .unwrap_or_else(|| "unknown".to_string());
 
     // Extract or generate traceparent - this establishes the trace context for all spans
     let traceparent: String = request_headers
-        .get("traceparent")
+        .get(TRACE_PARENT_HEADER)
         .and_then(|h| h.to_str().ok())
        .map(|s| s.to_string())
        .unwrap_or_else(|| {
@@ -51,7 +63,8 @@ pub async fn llm_chat(
     let chat_request_bytes = request.collect().await?.to_bytes();
 
     debug!(
-        "Received request body (raw utf8): {}",
+        "[PLANO_REQ_ID:{}] | REQUEST_BODY (UTF8): {}",
+        request_id,
         String::from_utf8_lossy(&chat_request_bytes)
     );
@@ -61,14 +74,19 @@ pub async fn llm_chat(
     )) {
         Ok(request) => request,
         Err(err) => {
-            warn!("Failed to parse request as ProviderRequestType: {}", err);
-            let err_msg = format!("Failed to parse request: {}", err);
+            warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err);
+            let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err);
             let mut bad_request = Response::new(full(err_msg));
             *bad_request.status_mut() = StatusCode::BAD_REQUEST;
             return Ok(bad_request);
         }
     };
 
+    // === v1/responses state management: Extract input items early ===
+    let mut original_input_items = Vec::new();
+    let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
+    let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)));
+
     // Model alias resolution: update model field in client_request immediately
     // This ensures all downstream objects use the resolved model
     let model_from_request = client_request.model().to_string();
@@ -83,9 +101,77 @@ pub async fn llm_chat(
     client_request.set_model(resolved_model.clone());
 
     if client_request.remove_metadata_key("archgw_preference_config") {
-        debug!("Removed archgw_preference_config from metadata");
+        debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id);
     }
 
+    // === v1/responses state management: Determine upstream API and combine input if needed ===
+    // Do this BEFORE routing since routing consumes the request
+    // Only process state if state_storage is configured
+    let mut should_manage_state = false;
+    if is_responses_api_client && state_storage.is_some() {
+        if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
+            // Extract original input once
+            original_input_items = extract_input_items(&responses_req.input);
+
+            // Get the upstream path and check if it's ResponsesAPI
+            let upstream_path = get_upstream_path(
+                &llm_providers,
+                &resolved_model,
+                &request_path,
+                &resolved_model,
+                is_streaming_request,
+            ).await;
+            let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
+
+            // Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
+            should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
+
+            if should_manage_state {
+                // Retrieve and combine conversation history if previous_response_id exists
+                if let Some(ref prev_resp_id) = responses_req.previous_response_id {
+                    match retrieve_and_combine_input(
+                        state_storage.as_ref().unwrap().clone(),
+                        prev_resp_id,
+                        original_input_items, // Pass ownership instead of cloning
+                    )
+                    .await
+                    {
+                        Ok(combined_input) => {
+                            // Update both the request and original_input_items
+                            responses_req.input = InputParam::Items(combined_input.clone());
+                            original_input_items = combined_input;
+                            info!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Updated request with conversation history ({} items)", request_id, original_input_items.len());
+                        }
+                        Err(StateStorageError::NotFound(_)) => {
+                            // Return 409 Conflict when previous_response_id not found
+                            warn!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Previous response_id not found: {}", request_id, prev_resp_id);
+                            let err_msg = format!(
+                                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Conversation state not found for previous_response_id: {}",
+                                request_id, prev_resp_id
+                            );
+                            let mut conflict_response = Response::new(full(err_msg));
+                            *conflict_response.status_mut() = StatusCode::CONFLICT;
+                            return Ok(conflict_response);
+                        }
+                        Err(e) => {
+                            // Log warning but continue on other storage errors
+                            warn!(
+                                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to retrieve conversation state for {}: {}",
+                                request_id, prev_resp_id, e
+                            );
+                            // Restore original_input_items since we passed ownership
+                            original_input_items = extract_input_items(&responses_req.input);
+                        }
+                    }
+                }
+            } else {
+                debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
+            }
+        }
+    }
+
     // Serialize request for upstream BEFORE router consumes it
     let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
 
     // Determine routing using the dedicated router_chat module
@@ -110,8 +196,8 @@ pub async fn llm_chat(
     let model_name = routing_result.model_name;
 
     debug!(
-        "[ARCH_ROUTER] URL: {}, Resolved Model: {}",
-        full_qualified_llm_provider_url, model_name
+        "[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Resolved Model: {}",
+        request_id, full_qualified_llm_provider_url, model_name
     );
 
     request_headers.insert(
@@ -173,15 +259,40 @@ pub async fn llm_chat(
         &llm_providers,
     ).await;
 
-    // Use PassthroughProcessor to track streaming metrics and finalize the span
-    let processor = ObservableStreamProcessor::new(
+    // Create base processor for metrics and tracing
+    let base_processor = ObservableStreamProcessor::new(
         trace_collector,
         operation_component::LLM,
         llm_span,
         request_start_time,
     );
 
-    let streaming_response = create_streaming_response(byte_stream, processor, 16);
+    // === v1/responses state management: Wrap with ResponsesStateProcessor ===
+    // Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
+    let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() {
+        // Extract Content-Encoding header to handle decompression for state parsing
+        let content_encoding = response_headers
+            .get("content-encoding")
+            .and_then(|v| v.to_str().ok())
+            .map(|s| s.to_string());
+
+        // Wrap with state management processor to store state after response completes
+        let state_processor = ResponsesStateProcessor::new(
+            base_processor,
+            state_storage.unwrap(),
+            original_input_items,
+            resolved_model.clone(),
+            model_name.clone(),
+            is_streaming_request,
+            false, // Not OpenAI upstream since should_manage_state is true
+            content_encoding,
+            request_id.clone(),
+        );
+        create_streaming_response(byte_stream, state_processor, 16)
+    } else {
+        // Use base processor without state management
+        create_streaming_response(byte_stream, base_processor, 16)
+    };
 
     match response.body(streaming_response.body) {
         Ok(response) => Ok(response),
@@ -301,35 +412,7 @@ async fn get_upstream_path(
     resolved_model: &str,
     is_streaming: bool,
 ) -> String {
-    let providers_lock = llm_providers.read().await;
-
-    // First, try to find by model name or provider name
-    let provider = providers_lock.iter().find(|p| {
-        p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
-            || p.name == model_name
-    });
-
-    let (provider_id, base_url_path_prefix) = if let Some(provider) = provider {
-        let provider_id = provider.provider_interface.to_provider_id();
-        let prefix = provider.base_url_path_prefix.clone();
-        (provider_id, prefix)
-    } else {
-        let default_provider = providers_lock.iter().find(|p| {
-            p.default.unwrap_or(false)
-        });
-        if let Some(provider) = default_provider {
-            let provider_id = provider.provider_interface.to_provider_id();
-            let prefix = provider.base_url_path_prefix.clone();
-            (provider_id, prefix)
-        } else {
-            // Last resort: use OpenAI as hardcoded fallback
-            warn!("No default provider found, falling back to OpenAI");
-            (hermesllm::ProviderId::OpenAI, None)
-        }
-    };
-
-    drop(providers_lock);
+    let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
 
     // Calculate the upstream path using the proper API
     let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
@@ -343,3 +426,37 @@ async fn get_upstream_path(
         base_url_path_prefix.as_deref(),
     )
 }
+
+/// Helper function to get provider info (ProviderId and base_url_path_prefix)
+async fn get_provider_info(
+    llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
+    model_name: &str,
+) -> (hermesllm::ProviderId, Option<String>) {
+    let providers_lock = llm_providers.read().await;
+
+    // First, try to find by model name or provider name
+    let provider = providers_lock.iter().find(|p| {
+        p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
+            || p.name == model_name
+    });
+
+    if let Some(provider) = provider {
+        let provider_id = provider.provider_interface.to_provider_id();
+        let prefix = provider.base_url_path_prefix.clone();
+        return (provider_id, prefix);
+    }
+
+    let default_provider = providers_lock.iter().find(|p| {
+        p.default.unwrap_or(false)
+    });
+
+    if let Some(provider) = default_provider {
+        let provider_id = provider.provider_interface.to_provider_id();
+        let prefix = provider.base_url_path_prefix.clone();
+        (provider_id, prefix)
+    } else {
+        // Last resort: use OpenAI as hardcoded fallback
+        warn!("No default provider found, falling back to OpenAI");
+        (hermesllm::ProviderId::OpenAI, None)
+    }
+}
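
To make the state-management flow in the handler above concrete, here is the client-visible behavior it enables, as a sketch: the model alias and input text below are invented for illustration. The first call is an ordinary v1/responses request, after which the gateway stores the exchange; the follow-up sends only the new input plus previous_response_id, and the gateway splices the stored history back in before forwarding to an upstream that only speaks chat completions.

    use serde_json::json;

    // Turn 2 of a conversation: the client sends only the new input and the
    // id of the prior response; the gateway rebuilds the full history.
    fn follow_up_body(previous_response_id: &str) -> serde_json::Value {
        json!({
            "model": "gpt-4o-mini",              // hypothetical alias
            "previous_response_id": previous_response_id,
            "input": "And what about tomorrow?"  // hypothetical user turn
        })
        // Per the diff, an id missing from the state store now yields
        // 409 Conflict instead of silently dropping the history.
    }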


@@ -1,4 +1,5 @@
 use common::configuration::ModelUsagePreference;
+use common::consts::{REQUEST_ID_HEADER};
 use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
 use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
 use hermesllm::{ProviderRequest, ProviderRequestType};
@@ -43,6 +44,10 @@ pub async fn router_chat_get_upstream_model(
 ) -> Result<RoutingResult, RoutingError> {
     // Clone metadata for routing before converting (which consumes client_request)
     let routing_metadata = client_request.metadata().clone();
+    let request_id = request_headers
+        .get(REQUEST_ID_HEADER)
+        .and_then(|value| value.to_str().ok())
+        .unwrap_or("unknown");
 
     // Convert to ChatCompletionsRequest for routing (regardless of input type)
     let chat_request = match ProviderRequestType::try_from((
@@ -73,7 +78,8 @@ pub async fn router_chat_get_upstream_model(
     };
 
     debug!(
-        "[ARCH_ROUTER REQ]: {}",
+        "[PLANO_REQ_ID: {}]: ROUTER_REQ: {}",
+        request_id,
         &serde_json::to_string(&chat_request).unwrap()
     );
@@ -114,14 +120,13 @@ pub async fn router_chat_get_upstream_model(
     };
 
     info!(
-        "request received, request type: chat_completion, usage preferences from request: {}, request path: {}, latest message: {}",
+        "[PLANO_REQ_ID: {}] | ROUTER_REQ | Usage preferences from request: {}, request_path: {}, latest message: {}",
+        request_id,
         usage_preferences.is_some(),
         request_path,
         latest_message_for_log
     );
 
-    debug!("usage preferences from request: {:?}", usage_preferences);
-
     // Capture start time for routing span
     let routing_start_time = std::time::Instant::now();
     let routing_start_system_time = std::time::SystemTime::now();
@@ -153,7 +158,8 @@ pub async fn router_chat_get_upstream_model(
         None => {
             // No route determined, use default model from request
             info!(
-                "No route determined, using default model from request: {}",
+                "[PLANO_REQ_ID: {}] | ROUTER_REQ | No route determined, using default model from request: {}",
+                request_id,
                 chat_request.model
            );