diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 09c86861..69e76fe3 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -308,11 +308,13 @@ name = "brightstaff" version = "0.1.0" dependencies = [ "async-openai", + "async-trait", "bytes", "chrono", "common", "eventsource-client", "eventsource-stream", + "flate2", "futures", "futures-util", "hermesllm", @@ -707,6 +709,16 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "flate2" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1533,6 +1545,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -2650,6 +2663,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + [[package]] name = "similar" version = "2.7.0" diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index 2d88e213..6d5012a7 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -5,11 +5,13 @@ edition = "2021" [dependencies] async-openai = "0.30.1" +async-trait = "0.1" bytes = "1.10.1" chrono = "0.4" common = { version = "0.1.0", path = "../common", features = ["trace-collection"] } eventsource-client = "0.15.0" eventsource-stream = "0.2.3" +flate2 = "1.0" futures = "0.3.31" futures-util = "0.3.31" hermesllm = { version = "0.1.0", path = "../hermesllm" } diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index b3686fae..5e744c8d 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -1,8 +1,9 @@ use bytes::Bytes; use common::configuration::{LlmProvider, ModelAlias}; -use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER}; +use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER}; use common::traces::TraceCollector; -use hermesllm::clients::SupportedAPIsFromClient; +use hermesllm::apis::openai_responses::InputParam; +use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use hermesllm::{ProviderRequest, ProviderRequestType}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full}; @@ -16,6 +17,11 @@ use tracing::{debug, warn}; use crate::router::llm_router::RouterService; use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message}; use crate::handlers::router_chat::router_chat_get_upstream_model; +use crate::state::response_state_processor::ResponsesStateProcessor; +use crate::state::{ + StateStorage, StateStorageError, + extract_input_items, retrieve_and_combine_input +}; use crate::tracing::operation_component; fn full>(chunk: T) -> BoxBody { @@ -31,14 +37,20 @@ pub async fn llm_chat( model_aliases: Arc>>, llm_providers: Arc>>, trace_collector: Arc, + state_storage: Arc, ) -> Result>, hyper::Error> { let request_path = request.uri().path().to_string(); let request_headers = request.headers().clone(); + let request_id = request_headers + .get(REQUEST_ID_HEADER) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_else(|| "unknown".to_string()); // Extract or generate traceparent - this establishes the trace context for all spans let traceparent: String = request_headers - .get("traceparent") + .get(TRACE_PARENT_HEADER) .and_then(|h| h.to_str().ok()) .map(|s| s.to_string()) .unwrap_or_else(|| { @@ -51,7 +63,8 @@ pub async fn llm_chat( let chat_request_bytes = request.collect().await?.to_bytes(); debug!( - "Received request body (raw utf8): {}", + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | REQUEST BODY (raw utf8): {}", + request_id, String::from_utf8_lossy(&chat_request_bytes) ); @@ -61,14 +74,19 @@ pub async fn llm_chat( )) { Ok(request) => request, Err(err) => { - warn!("Failed to parse request as ProviderRequestType: {}", err); - let err_msg = format!("Failed to parse request: {}", err); + warn!("[PLANO_REQ_ID:{}] | BRIGHTSTAFF | Failed to parse request as ProviderRequestType: {}", request_id, err); + let err_msg = format!("[PLANO_REQ_ID:{}] | BRIGHTSTAFF | Failed to parse request: {}", request_id, err); let mut bad_request = Response::new(full(err_msg)); *bad_request.status_mut() = StatusCode::BAD_REQUEST; return Ok(bad_request); } }; + // === v1/responses state management: Extract input items early === + let mut original_input_items = Vec::new(); + let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str()); + let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))); + // Model alias resolution: update model field in client_request immediately // This ensures all downstream objects use the resolved model let model_from_request = client_request.model().to_string(); @@ -83,9 +101,76 @@ pub async fn llm_chat( client_request.set_model(resolved_model.clone()); if client_request.remove_metadata_key("archgw_preference_config") { - debug!("Removed archgw_preference_config from metadata"); + debug!("[PLANO (BRIGHTSTAFF)] Removed archgw_preference_config from metadata"); } + // === v1/responses state management: Determine upstream API and combine input if needed === + // Do this BEFORE routing since routing consumes the request + let mut should_manage_state = false; + if is_responses_api_client { + if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request { + // Extract original input once + original_input_items = extract_input_items(&responses_req.input); + + // Get the upstream path and check if it's ResponsesAPI + let upstream_path = get_upstream_path( + &llm_providers, + &resolved_model, + &request_path, + &resolved_model, + is_streaming_request, + ).await; + + let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path); + + // Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation) + should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))); + + if should_manage_state { + // Retrieve and combine conversation history if previous_response_id exists + if let Some(ref prev_resp_id) = responses_req.previous_response_id { + match retrieve_and_combine_input( + state_storage.clone(), + prev_resp_id, + original_input_items, // Pass ownership instead of cloning + ) + .await + { + Ok(combined_input) => { + // Update both the request and original_input_items + responses_req.input = InputParam::Items(combined_input.clone()); + original_input_items = combined_input; + debug!("[PLANO (BRIGHTSTAFF)] Updated request with conversation history ({} items)", original_input_items.len()); + } + Err(StateStorageError::NotFound(_)) => { + // Return 409 Conflict when previous_response_id not found + warn!("[PLANO (BRIGHTSTAFF)] Previous response_id not found: {}", prev_resp_id); + let err_msg = format!( + "[PLANO (BRIGHTSTAFF)] Conversation state not found for previous_response_id: {}", + prev_resp_id + ); + let mut conflict_response = Response::new(full(err_msg)); + *conflict_response.status_mut() = StatusCode::CONFLICT; + return Ok(conflict_response); + } + Err(e) => { + // Log warning but continue on other storage errors + warn!( + "Failed to retrieve conversation state for {}: {}", + prev_resp_id, e + ); + // Restore original_input_items since we passed ownership + original_input_items = extract_input_items(&responses_req.input); + } + } + } + } else { + debug!("[PLANO (BRIGHTSTAFF)] Upstream supports ResponsesAPI natively, passing through without state management"); + } + } + } + + // Serialize request for upstream BEFORE router consumes it let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap(); // Determine routing using the dedicated router_chat module @@ -110,7 +195,7 @@ pub async fn llm_chat( let model_name = routing_result.model_name; debug!( - "[ARCH_ROUTER] URL: {}, Resolved Model: {}", + "[PLANO ARCH_ROUTER] URL: {}, Resolved Model: {}", full_qualified_llm_provider_url, model_name ); @@ -173,15 +258,40 @@ pub async fn llm_chat( &llm_providers, ).await; - // Use PassthroughProcessor to track streaming metrics and finalize the span - let processor = ObservableStreamProcessor::new( + // Create base processor for metrics and tracing + let base_processor = ObservableStreamProcessor::new( trace_collector, operation_component::LLM, llm_span, request_start_time, ); - let streaming_response = create_streaming_response(byte_stream, processor, 16); + // === v1/responses state management: Wrap with ResponsesStateProcessor === + // Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI) + let streaming_response = if should_manage_state && !original_input_items.is_empty() { + // Extract Content-Encoding header to handle decompression for state parsing + let content_encoding = response_headers + .get("content-encoding") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + // Wrap with state management processor to store state after response completes + let state_processor = ResponsesStateProcessor::new( + base_processor, + state_storage, + original_input_items, + resolved_model.clone(), + model_name.clone(), + is_streaming_request, + false, // Not OpenAI upstream since should_manage_state is true + content_encoding, + request_id.clone(), + ); + create_streaming_response(byte_stream, state_processor, 16) + } else { + // Use base processor without state management + create_streaming_response(byte_stream, base_processor, 16) + }; match response.body(streaming_response.body) { Ok(response) => Ok(response), @@ -301,35 +411,7 @@ async fn get_upstream_path( resolved_model: &str, is_streaming: bool, ) -> String { - let providers_lock = llm_providers.read().await; - - // First, try to find by model name or provider name - let provider = providers_lock.iter().find(|p| { - p.model.as_ref().map(|m| m == model_name).unwrap_or(false) - || p.name == model_name - }); - - let (provider_id, base_url_path_prefix) = if let Some(provider) = provider { - let provider_id = provider.provider_interface.to_provider_id(); - let prefix = provider.base_url_path_prefix.clone(); - (provider_id, prefix) - } else { - let default_provider = providers_lock.iter().find(|p| { - p.default.unwrap_or(false) - }); - - if let Some(provider) = default_provider { - let provider_id = provider.provider_interface.to_provider_id(); - let prefix = provider.base_url_path_prefix.clone(); - (provider_id, prefix) - } else { - // Last resort: use OpenAI as hardcoded fallback - warn!("No default provider found, falling back to OpenAI"); - (hermesllm::ProviderId::OpenAI, None) - } - }; - - drop(providers_lock); + let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await; // Calculate the upstream path using the proper API let client_api = SupportedAPIsFromClient::from_endpoint(request_path) @@ -343,3 +425,37 @@ async fn get_upstream_path( base_url_path_prefix.as_deref(), ) } + +/// Helper function to get provider info (ProviderId and base_url_path_prefix) +async fn get_provider_info( + llm_providers: &Arc>>, + model_name: &str, +) -> (hermesllm::ProviderId, Option) { + let providers_lock = llm_providers.read().await; + + // First, try to find by model name or provider name + let provider = providers_lock.iter().find(|p| { + p.model.as_ref().map(|m| m == model_name).unwrap_or(false) + || p.name == model_name + }); + + if let Some(provider) = provider { + let provider_id = provider.provider_interface.to_provider_id(); + let prefix = provider.base_url_path_prefix.clone(); + return (provider_id, prefix); + } + + let default_provider = providers_lock.iter().find(|p| { + p.default.unwrap_or(false) + }); + + if let Some(provider) = default_provider { + let provider_id = provider.provider_interface.to_provider_id(); + let prefix = provider.base_url_path_prefix.clone(); + (provider_id, prefix) + } else { + // Last resort: use OpenAI as hardcoded fallback + warn!("No default provider found, falling back to OpenAI"); + (hermesllm::ProviderId::OpenAI, None) + } +} diff --git a/crates/brightstaff/src/lib.rs b/crates/brightstaff/src/lib.rs index ceff49f1..36fc902f 100644 --- a/crates/brightstaff/src/lib.rs +++ b/crates/brightstaff/src/lib.rs @@ -1,4 +1,5 @@ pub mod handlers; pub mod router; +pub mod state; pub mod tracing; pub mod utils; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index d0241fa3..78be13a5 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -3,6 +3,8 @@ use brightstaff::handlers::llm::llm_chat; use brightstaff::handlers::models::list_models; use brightstaff::handlers::function_calling::{function_calling_chat_handler}; use brightstaff::router::llm_router::RouterService; +use brightstaff::state::memory::MemoryConversationalStorage; +use brightstaff::state::StateStorage; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; use common::configuration::Configuration; @@ -101,6 +103,11 @@ async fn main() -> Result<(), Box> { let trace_collector = Arc::new(TraceCollector::new(tracing_enabled)); let _flusher_handle = trace_collector.clone().start_background_flusher(); + // Initialize conversation state storage for v1/responses + // TODO: Make this configurable (MEMORY vs SUPABASE) via arch_config.yaml + let state_storage: Arc = Arc::new(MemoryConversationalStorage::new()); + info!("Initialized conversation state storage: Memory"); + loop { let (stream, _) = listener.accept().await?; @@ -115,6 +122,7 @@ async fn main() -> Result<(), Box> { let agents_list = agents_list.clone(); let listeners = listeners.clone(); let trace_collector = trace_collector.clone(); + let state_storage = state_storage.clone(); let service = service_fn(move |req| { let router_service = Arc::clone(&router_service); let parent_cx = extract_context_from_request(&req); @@ -124,13 +132,14 @@ async fn main() -> Result<(), Box> { let agents_list = agents_list.clone(); let listeners = listeners.clone(); let trace_collector = trace_collector.clone(); + let state_storage = state_storage.clone(); async move { match (req.method(), req.uri().path()) { (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => { let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path()); - llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector) + llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage) .with_context(parent_cx) .await } diff --git a/crates/brightstaff/src/state/memory.rs b/crates/brightstaff/src/state/memory.rs new file mode 100644 index 00000000..3cd5e39c --- /dev/null +++ b/crates/brightstaff/src/state/memory.rs @@ -0,0 +1,584 @@ +use super::{OpenAIConversationState, StateStorage, StateStorageError}; +use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{debug, warn}; + +/// In-memory storage backend for conversation state +/// Uses a HashMap wrapped in Arc> for thread-safe access +#[derive(Clone)] +pub struct MemoryConversationalStorage { + storage: Arc>>, +} + +impl MemoryConversationalStorage { + pub fn new() -> Self { + Self { + storage: Arc::new(RwLock::new(HashMap::new())), + } + } +} + +impl Default for MemoryConversationalStorage { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl StateStorage for MemoryConversationalStorage { + async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> { + let response_id = state.response_id.clone(); + let mut storage = self.storage.write().await; + + debug!( + "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Storing conversation state: model={}, provider={}, input_items={}", + response_id, state.model, state.provider, state.input_items.len() + ); + + storage.insert(response_id, state); + Ok(()) + } + + async fn get(&self, response_id: &str) -> Result { + let storage = self.storage.read().await; + + match storage.get(response_id) { + Some(state) => { + debug!( + "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Retrieved conversation state: input_items={}", + response_id, state.input_items.len() + ); + Ok(state.clone()) + } + None => { + warn!( + "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Conversation state not found", + response_id + ); + Err(StateStorageError::NotFound(response_id.to_string())) + } + } + } + + async fn exists(&self, response_id: &str) -> Result { + let storage = self.storage.read().await; + Ok(storage.contains_key(response_id)) + } + + async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> { + let mut storage = self.storage.write().await; + + if storage.remove(response_id).is_some() { + debug!( + "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Deleted conversation state", + response_id + ); + Ok(()) + } else { + Err(StateStorageError::NotFound(response_id.to_string())) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent}; + + fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState { + let mut input_items = Vec::new(); + for i in 0..num_messages { + input_items.push(InputItem::Message(InputMessage { + role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant }, + content: vec![InputContent::InputText { + text: format!("Message {}", i), + }], + })); + } + + OpenAIConversationState { + response_id: response_id.to_string(), + input_items, + created_at: 1234567890, + model: "claude-3".to_string(), + provider: "anthropic".to_string(), + } + } + + #[tokio::test] + async fn test_put_and_get_success() { + let storage = MemoryConversationalStorage::new(); + let state: OpenAIConversationState = create_test_state("resp_001", 3); + + // Store + storage.put(state.clone()).await.unwrap(); + + // Retrieve + let retrieved = storage.get("resp_001").await.unwrap(); + assert_eq!(retrieved.response_id, state.response_id); + assert_eq!(retrieved.model, state.model); + assert_eq!(retrieved.provider, state.provider); + assert_eq!(retrieved.input_items.len(), 3); + assert_eq!(retrieved.created_at, state.created_at); + } + + #[tokio::test] + async fn test_put_overwrites_existing() { + let storage = MemoryConversationalStorage::new(); + + // First state + let state1 = create_test_state("resp_002", 2); + storage.put(state1).await.unwrap(); + + // Overwrite with new state + let state2 = OpenAIConversationState { + response_id: "resp_002".to_string(), + input_items: vec![], + created_at: 9999999999, + model: "gpt-4".to_string(), + provider: "openai".to_string(), + }; + storage.put(state2.clone()).await.unwrap(); + + // Should retrieve the new state + let retrieved = storage.get("resp_002").await.unwrap(); + assert_eq!(retrieved.model, "gpt-4"); + assert_eq!(retrieved.provider, "openai"); + assert_eq!(retrieved.input_items.len(), 0); + assert_eq!(retrieved.created_at, 9999999999); + } + + #[tokio::test] + async fn test_get_not_found() { + let storage = MemoryConversationalStorage::new(); + + let result = storage.get("nonexistent").await; + assert!(result.is_err()); + + match result.unwrap_err() { + StateStorageError::NotFound(id) => { + assert_eq!(id, "nonexistent"); + } + _ => panic!("Expected NotFound error"), + } + } + + #[tokio::test] + async fn test_exists_returns_false_for_nonexistent() { + let storage = MemoryConversationalStorage::new(); + assert!(!storage.exists("resp_003").await.unwrap()); + } + + #[tokio::test] + async fn test_exists_returns_true_after_put() { + let storage = MemoryConversationalStorage::new(); + let state = create_test_state("resp_004", 1); + + assert!(!storage.exists("resp_004").await.unwrap()); + storage.put(state).await.unwrap(); + assert!(storage.exists("resp_004").await.unwrap()); + } + + #[tokio::test] + async fn test_delete_success() { + let storage = MemoryConversationalStorage::new(); + let state = create_test_state("resp_005", 2); + + storage.put(state).await.unwrap(); + assert!(storage.exists("resp_005").await.unwrap()); + + // Delete + storage.delete("resp_005").await.unwrap(); + + // Should no longer exist + assert!(!storage.exists("resp_005").await.unwrap()); + assert!(storage.get("resp_005").await.is_err()); + } + + #[tokio::test] + async fn test_delete_not_found() { + let storage = MemoryConversationalStorage::new(); + + let result = storage.delete("nonexistent").await; + assert!(result.is_err()); + + match result.unwrap_err() { + StateStorageError::NotFound(id) => { + assert_eq!(id, "nonexistent"); + } + _ => panic!("Expected NotFound error"), + } + } + + #[tokio::test] + async fn test_merge_combines_inputs() { + let storage = MemoryConversationalStorage::new(); + + // Create a previous state with 2 messages + let prev_state = create_test_state("resp_006", 2); + + // Create current input with 1 message + let current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "New message".to_string(), + }], + })]; + + // Merge + let merged = storage.merge(&prev_state, current_input); + + // Should have 3 messages total (2 from prev + 1 current) + assert_eq!(merged.len(), 3); + } + + #[tokio::test] + async fn test_merge_preserves_order() { + let storage = MemoryConversationalStorage::new(); + + // Previous state has messages 0 and 1 + let prev_state = create_test_state("resp_007", 2); + + // Current input has message 2 + let current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "Message 2".to_string(), + }], + })]; + + let merged = storage.merge(&prev_state, current_input); + + // Verify order: prev messages first, then current + let InputItem::Message(msg) = &merged[0]; + match &msg.content[0] { + InputContent::InputText { text } => assert_eq!(text, "Message 0"), + _ => panic!("Expected InputText"), + } + + let InputItem::Message(msg) = &merged[2]; + match &msg.content[0] { + InputContent::InputText { text } => assert_eq!(text, "Message 2"), + _ => panic!("Expected InputText"), + } + } + + #[tokio::test] + async fn test_merge_with_empty_current_input() { + let storage = MemoryConversationalStorage::new(); + let prev_state = create_test_state("resp_008", 3); + + let merged = storage.merge(&prev_state, vec![]); + + // Should just have the previous state's items + assert_eq!(merged.len(), 3); + } + + #[tokio::test] + async fn test_merge_with_empty_previous_state() { + let storage = MemoryConversationalStorage::new(); + + let prev_state = OpenAIConversationState { + response_id: "resp_009".to_string(), + input_items: vec![], + created_at: 1234567890, + model: "gpt-4".to_string(), + provider: "openai".to_string(), + }; + + let current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "Only message".to_string(), + }], + })]; + + let merged = storage.merge(&prev_state, current_input); + + // Should just have the current input + assert_eq!(merged.len(), 1); + } + + #[tokio::test] + async fn test_concurrent_access() { + let storage = MemoryConversationalStorage::new(); + + // Spawn multiple tasks that write concurrently + let mut handles = vec![]; + + for i in 0..10 { + let storage_clone = storage.clone(); + let handle = tokio::spawn(async move { + let state = create_test_state(&format!("resp_{}", i), i % 3); + storage_clone.put(state).await.unwrap(); + }); + handles.push(handle); + } + + // Wait for all tasks + for handle in handles { + handle.await.unwrap(); + } + + // Verify all states were stored + for i in 0..10 { + assert!(storage.exists(&format!("resp_{}", i)).await.unwrap()); + } + } + + #[tokio::test] + async fn test_multiple_operations_on_same_id() { + let storage = MemoryConversationalStorage::new(); + let state = create_test_state("resp_010", 1); + + // Put + storage.put(state.clone()).await.unwrap(); + + // Get + let retrieved = storage.get("resp_010").await.unwrap(); + assert_eq!(retrieved.response_id, "resp_010"); + + // Exists + assert!(storage.exists("resp_010").await.unwrap()); + + // Put again (overwrite) + let new_state = create_test_state("resp_010", 5); + storage.put(new_state).await.unwrap(); + + // Get updated + let updated = storage.get("resp_010").await.unwrap(); + assert_eq!(updated.input_items.len(), 5); + + // Delete + storage.delete("resp_010").await.unwrap(); + + // Should not exist + assert!(!storage.exists("resp_010").await.unwrap()); + } + + #[tokio::test] + async fn test_merge_with_tool_call_flow() { + // This test simulates a realistic tool call conversation flow: + // 1. User sends message: "What's the weather?" + // 2. Model responds with function call (converted to assistant message) + // 3. User sends function call output in next request with previous_response_id + // The merge should combine: user message + assistant function call + function output + + let storage = MemoryConversationalStorage::new(); + + // Step 1: Previous state contains the initial exchange + // - User message: "What's the weather in SF?" + // - Assistant message (converted from FunctionCall): "Called function: get_weather..." + let prev_state = OpenAIConversationState { + response_id: "resp_tool_001".to_string(), + input_items: vec![ + // Original user message + InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "What's the weather in San Francisco?".to_string(), + }], + }), + // Assistant's function call (converted from OutputItem::FunctionCall) + InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: vec![InputContent::InputText { + text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(), + }], + }), + ], + created_at: 1234567890, + model: "claude-3".to_string(), + provider: "anthropic".to_string(), + }; + + // Step 2: Current request includes function call output + let current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(), + }], + })]; + + // Step 3: Merge should combine all conversation history + let merged = storage.merge(&prev_state, current_input); + + // Should have 3 items: user question + assistant function call + function output + assert_eq!(merged.len(), 3); + + // Verify the order and content + let InputItem::Message(msg1) = &merged[0]; + assert!(matches!(msg1.role, MessageRole::User)); + match &msg1.content[0] { + InputContent::InputText { text } => { + assert!(text.contains("weather in San Francisco")); + } + _ => panic!("Expected InputText"), + } + + let InputItem::Message(msg2) = &merged[1]; + assert!(matches!(msg2.role, MessageRole::Assistant)); + match &msg2.content[0] { + InputContent::InputText { text } => { + assert!(text.contains("get_weather")); + } + _ => panic!("Expected InputText"), + } + + let InputItem::Message(msg3) = &merged[2]; + assert!(matches!(msg3.role, MessageRole::User)); + match &msg3.content[0] { + InputContent::InputText { text } => { + assert!(text.contains("Function result")); + assert!(text.contains("temperature")); + } + _ => panic!("Expected InputText"), + } + } + + #[tokio::test] + async fn test_merge_with_multiple_tool_calls() { + // Test a more complex scenario with multiple tool calls + let storage = MemoryConversationalStorage::new(); + + // Previous state has: user message + 2 function calls from assistant + let prev_state = OpenAIConversationState { + response_id: "resp_tool_002".to_string(), + input_items: vec![ + InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "What's the weather and time in SF?".to_string(), + }], + }), + InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: vec![InputContent::InputText { + text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(), + }], + }), + InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: vec![InputContent::InputText { + text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(), + }], + }), + ], + created_at: 1234567890, + model: "gpt-4".to_string(), + provider: "openai".to_string(), + }; + + // Current input: function outputs for both calls + let current_input = vec![ + InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "Weather result: {\"temp\": 68}".to_string(), + }], + }), + InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "Time result: {\"time\": \"14:30\"}".to_string(), + }], + }), + ]; + + let merged = storage.merge(&prev_state, current_input); + + // Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs + assert_eq!(merged.len(), 5); + + // Verify first item is original user message + let InputItem::Message(first) = &merged[0]; + assert!(matches!(first.role, MessageRole::User)); + + // Verify last two are function outputs + let InputItem::Message(second_last) = &merged[3]; + assert!(matches!(second_last.role, MessageRole::User)); + match &second_last.content[0] { + InputContent::InputText { text } => assert!(text.contains("Weather result")), + _ => panic!("Expected InputText"), + } + + let InputItem::Message(last) = &merged[4]; + assert!(matches!(last.role, MessageRole::User)); + match &last.content[0] { + InputContent::InputText { text } => assert!(text.contains("Time result")), + _ => panic!("Expected InputText"), + } + } + + #[tokio::test] + async fn test_merge_preserves_conversation_context_for_multi_turn() { + // Simulate a multi-turn conversation with tool calls + let storage = MemoryConversationalStorage::new(); + + // Previous state: full conversation history up to this point + let prev_state = OpenAIConversationState { + response_id: "resp_tool_003".to_string(), + input_items: vec![ + // Turn 1: User asks about weather + InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "What's the weather?".to_string(), + }], + }), + // Turn 1: Assistant calls get_weather + InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: vec![InputContent::InputText { + text: "Called function: get_weather".to_string(), + }], + }), + // Turn 2: User provides function output + InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "Weather: sunny, 72°F".to_string(), + }], + }), + // Turn 2: Assistant responds with text + InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: vec![InputContent::InputText { + text: "It's sunny and 72°F in San Francisco today!".to_string(), + }], + }), + ], + created_at: 1234567890, + model: "claude-3".to_string(), + provider: "anthropic".to_string(), + }; + + // Turn 3: User asks follow-up question + let current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "Should I bring an umbrella?".to_string(), + }], + })]; + + let merged = storage.merge(&prev_state, current_input); + + // Should have all 5 messages in order + assert_eq!(merged.len(), 5); + + // Verify the entire conversation flow is preserved + let InputItem::Message(first) = &merged[0]; + match &first.content[0] { + InputContent::InputText { text } => assert!(text.contains("What's the weather")), + _ => panic!("Expected InputText"), + } + + let InputItem::Message(last) = &merged[4]; + match &last.content[0] { + InputContent::InputText { text } => assert!(text.contains("umbrella")), + _ => panic!("Expected InputText"), + } + } +} diff --git a/crates/brightstaff/src/state/mod.rs b/crates/brightstaff/src/state/mod.rs new file mode 100644 index 00000000..2eedae6f --- /dev/null +++ b/crates/brightstaff/src/state/mod.rs @@ -0,0 +1,157 @@ +use async_trait::async_trait; +use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageRole, InputParam}; +use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::fmt; +use std::sync::Arc; +use tracing::{debug, info}; + +pub mod memory; +pub mod response_state_processor; +pub mod supabase; + +/// Represents the conversational state for a v1/responses request +/// Contains the complete input/output history that can be restored +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenAIConversationState { + /// The response ID this state is associated with + pub response_id: String, + + /// The complete input history (original input + accumulated outputs) + /// This is what gets prepended to new requests via previous_response_id + pub input_items: Vec, + + /// Timestamp when this state was created + pub created_at: i64, + + /// Model used for this response + pub model: String, + + /// Provider that generated this response (e.g., "anthropic", "openai") + pub provider: String, +} + +/// Error types for state storage operations +#[derive(Debug)] +pub enum StateStorageError { + /// State not found for given response_id + NotFound(String), + + /// Storage backend error (network, database, etc.) + StorageError(String), + + /// Serialization/deserialization error + SerializationError(String), +} + +impl fmt::Display for StateStorageError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id), + StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg), + StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg), + } + } +} + +impl Error for StateStorageError {} + +/// Trait for conversation state storage backends +#[async_trait] +pub trait StateStorage: Send + Sync { + /// Store conversation state for a response + async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError>; + + /// Retrieve conversation state by response_id + async fn get(&self, response_id: &str) -> Result; + + /// Check if state exists for a response_id + async fn exists(&self, response_id: &str) -> Result; + + /// Delete state for a response_id (optional, for cleanup) + async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>; + + fn merge( + &self, + prev_state: &OpenAIConversationState, + current_input: Vec, + ) -> Vec { + // Default implementation: prepend previous input, append current + let prev_count = prev_state.input_items.len(); + let current_count = current_input.len(); + + let mut combined_input = prev_state.input_items.clone(); + combined_input.extend(current_input); + + debug!( + "PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}", + prev_state.response_id, + prev_count, + current_count, + combined_input.len(), + serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string()) + ); + + combined_input + } +} + + + +/// Storage backend type enum +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StorageBackend { + Memory, + Supabase, +} + +impl StorageBackend { + pub fn from_str(s: &str) -> Option { + match s.to_lowercase().as_str() { + "memory" => Some(StorageBackend::Memory), + "supabase" => Some(StorageBackend::Supabase), + _ => None, + } + } +} + +// === Utility functions for state management === + +/// Extract input items from InputParam, converting text to structured format +pub fn extract_input_items(input: &InputParam) -> Vec { + match input { + InputParam::Text(text) => { + vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: text.clone(), + }], + })] + } + InputParam::Items(items) => items.clone(), + } +} + +/// Retrieve previous conversation state and combine with current input +/// Returns combined input if previous state found, or original input if not found/error +pub async fn retrieve_and_combine_input( + storage: Arc, + previous_response_id: &str, + current_input: Vec, +) -> Result, StateStorageError> { + info!( + "Retrieving conversation state for previous_response_id: {}", + previous_response_id + ); + + // First get the previous state + let prev_state = storage.get(previous_response_id).await?; + let combined_input = storage.merge(&prev_state, current_input); + + debug!( + "Retrieved and merged conversation state: {} total input items", + combined_input.len() + ); + + Ok(combined_input) +} diff --git a/crates/brightstaff/src/state/response_state_processor.rs b/crates/brightstaff/src/state/response_state_processor.rs new file mode 100644 index 00000000..c634df4e --- /dev/null +++ b/crates/brightstaff/src/state/response_state_processor.rs @@ -0,0 +1,307 @@ +use bytes::Bytes; +use flate2::read::GzDecoder; +use hermesllm::apis::openai_responses::{ + InputItem, OutputItem, ResponsesAPIStreamEvent, +}; +use hermesllm::apis::streaming_shapes::sse::SseStreamIter; +use hermesllm::transforms::response::output_to_input::outputs_to_inputs; +use std::io::Read; +use std::sync::Arc; +use tracing::{info, debug, warn}; + +use crate::handlers::utils::StreamProcessor; +use crate::state::{OpenAIConversationState, StateStorage}; + +/// Processor that wraps another processor and handles v1/responses state management +/// Captures response_id and output from streaming responses, stores state after completion +pub struct ResponsesStateProcessor { + /// The underlying processor (e.g., ObservableStreamProcessor for metrics) + inner: P, + + /// State storage backend + storage: Arc, + + /// Original input items from the request + original_input: Vec, + + /// Model name + model: String, + + /// Provider name + provider: String, + + /// Whether this is a streaming request + is_streaming: bool, + + /// Whether upstream is OpenAI (skip storage if true) + is_openai_upstream: bool, + + /// Content-Encoding header value (e.g., "gzip", "br", None) + content_encoding: Option, + + /// Request ID for logging + request_id: String, + + /// Buffer for accumulating chunks (needed for non-streaming compressed responses) + chunk_buffer: Vec, + + /// Captured response_id from response.completed event + response_id: Option, + + /// Captured output items from response.completed event + output_items: Option>, +} + +impl ResponsesStateProcessor

{ + pub fn new( + inner: P, + storage: Arc, + original_input: Vec, + model: String, + provider: String, + is_streaming: bool, + is_openai_upstream: bool, + content_encoding: Option, + request_id: String, + ) -> Self { + Self { + inner, + storage, + original_input, + model, + provider, + is_streaming, + is_openai_upstream, + content_encoding, + request_id, + chunk_buffer: Vec::new(), + response_id: None, + output_items: None, + } + } + + /// Decompress accumulated buffer based on Content-Encoding header + fn decompress_buffer(&self) -> Vec { + if self.chunk_buffer.is_empty() { + return Vec::new(); + } + + match self.content_encoding.as_deref() { + Some("gzip") => { + let mut decoder = GzDecoder::new(self.chunk_buffer.as_slice()); + let mut decompressed = Vec::new(); + match decoder.read_to_end(&mut decompressed) { + Ok(_) => { + debug!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Successfully decompressed {} bytes to {} bytes", + self.request_id, + self.chunk_buffer.len(), + decompressed.len() + ); + decompressed + } + Err(e) => { + warn!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Failed to decompress gzip buffer: {}", + self.request_id, + e + ); + self.chunk_buffer.clone() + } + } + } + Some(encoding) => { + warn!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Unsupported Content-Encoding: {}. Only gzip is currently supported.", + self.request_id, + encoding + ); + self.chunk_buffer.clone() + } + None => self.chunk_buffer.clone(), + } + } + + /// Parse response to extract response_id and output + /// For streaming: parse SSE events looking for response.completed (per chunk) + /// For non-streaming: buffer all chunks, then decompress and parse on completion + fn try_parse_response_chunk(&mut self, chunk: &[u8]) { + if self.is_streaming { + // Streaming: Try to parse SSE events from this chunk + // Note: For compressed streaming, we'd need to buffer and decompress first + // but most streaming responses aren't compressed since SSE needs to be readable + let sse_iter = match SseStreamIter::try_from(chunk) { + Ok(iter) => iter, + Err(_) => return, // Not valid SSE format, skip + }; + + // Process each SSE event in the chunk, looking for data lines with response.completed + for event in sse_iter { + // Only process data lines (skip event-only lines) + if let Some(data_str) = &event.data { + // Try to parse as ResponsesAPIStreamEvent + if let Ok(stream_event) = serde_json::from_str::(data_str) { + // Check if this is a ResponseCompleted event + if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event { + debug!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}, output_json={}", + self.request_id, + response.id, + response.output.len(), + serde_json::to_string(&response.output).unwrap_or_else(|_| "serialization_error".to_string()) + ); + self.response_id = Some(response.id.clone()); + self.output_items = Some(response.output.clone()); + return; // Found what we need, exit early + } + } + } + } + } else { + // Non-streaming: Buffer chunks, will decompress and parse on completion + self.chunk_buffer.extend_from_slice(chunk); + } + } + + /// Parse buffered non-streaming response (called on completion) + fn try_parse_buffered_response(&mut self) { + if self.is_streaming || self.chunk_buffer.is_empty() { + return; + } + + // Decompress if needed + let decompressed = self.decompress_buffer(); + + // Parse complete JSON response + match serde_json::from_slice::(&decompressed) { + Ok(response) => { + info!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}", + self.request_id, + response.id, + response.output.len() + ); + self.response_id = Some(response.id.clone()); + self.output_items = Some(response.output.clone()); + } + Err(e) => { + // Log parse error with chunk preview for debugging + let chunk_preview = String::from_utf8_lossy(&decompressed); + let preview_len = chunk_preview.len().min(200); + warn!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Failed to parse non-streaming ResponsesAPIResponse: {}. Decompressed preview (first {} bytes): {}", + self.request_id, + e, + preview_len, + &chunk_preview[..preview_len] + ); + } + } + } +} + +impl StreamProcessor for ResponsesStateProcessor

{ + fn process_chunk(&mut self, chunk: Bytes) -> Result, String> { + // Buffer/parse chunk for response extraction + self.try_parse_response_chunk(&chunk); + + // Forward to inner processor + self.inner.process_chunk(chunk) + } + + fn on_first_bytes(&mut self) { + self.inner.on_first_bytes(); + } + + fn on_complete(&mut self) { + // For non-streaming, decompress and parse buffered response + self.try_parse_buffered_response(); + + // First, let the inner processor complete + self.inner.on_complete(); + + // Skip storage for OpenAI upstream + if self.is_openai_upstream { + debug!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Skipping state storage for OpenAI upstream provider", + self.request_id + ); + return; + } + + // Store state if we captured response_id and output + if let (Some(response_id), Some(output_items)) = (&self.response_id, &self.output_items) { + debug!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Output items before conversion: {}", + self.request_id, + serde_json::to_string(&output_items).unwrap_or_else(|_| "serialization_error".to_string()) + ); + + // Convert output items to input items for next request + let output_as_inputs = outputs_to_inputs(output_items); + + debug!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Converting outputs to inputs: output_items_count={}, converted_input_items_count={}", + self.request_id, output_items.len(), output_as_inputs.len() + ); + + // Combine original input + output as new input history + let mut combined_input = self.original_input.clone(); + combined_input.extend(output_as_inputs); + + debug!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Storing state: original_input_count={}, combined_input_count={}, combined_json={}", + self.request_id, + self.original_input.len(), + combined_input.len(), + serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string()) + ); + + let state = OpenAIConversationState { + response_id: response_id.clone(), + input_items: combined_input, + created_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64, + model: self.model.clone(), + provider: self.provider.clone(), + }; + + // Store asynchronously (fire and forget with logging) + let storage = self.storage.clone(); + let response_id_clone = response_id.clone(); + let request_id = self.request_id.clone(); + tokio::spawn(async move { + match storage.put(state).await { + Ok(()) => { + debug!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Successfully stored conversation state for response_id: {}", + request_id, + response_id_clone + ); + } + Err(e) => { + warn!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Failed to store conversation state for response_id {}: {}", + request_id, + response_id_clone, + e + ); + } + } + }); + } else { + warn!( + "[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | No response_id captured from upstream response - cannot store conversation state. response_id present: {}, output present: {}", + self.request_id, + self.response_id.is_some(), + self.output_items.is_some() + ); + } + } + + fn on_error(&mut self, error: &str) { + self.inner.on_error(error); + } +} diff --git a/crates/brightstaff/src/state/supabase.rs b/crates/brightstaff/src/state/supabase.rs new file mode 100644 index 00000000..71c032ad --- /dev/null +++ b/crates/brightstaff/src/state/supabase.rs @@ -0,0 +1,223 @@ +use super::{OpenAIConversationState, StateStorage, StateStorageError}; +use async_trait::async_trait; +use tracing::{debug, warn}; + +/// Supabase/PostgreSQL storage backend for conversation state +/// This is a placeholder implementation that can be extended with actual PostgreSQL logic +#[derive(Clone)] +pub struct SupabaseConversationalStorage { + // Connection pool or client would go here + // e.g., sqlx::PgPool or tokio_postgres::Client + _connection_string: String, +} + +impl SupabaseConversationalStorage { + pub fn new(connection_string: String) -> Self { + Self { + _connection_string: connection_string, + } + } +} + +#[async_trait] +impl StateStorage for SupabaseConversationalStorage { + async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> { + warn!( + "Supabase storage not yet implemented - would store response_id: {}", + state.response_id + ); + + // TODO: Implement PostgreSQL storage + // SQL: INSERT INTO conversation_states (response_id, input_items, created_at, model, provider) + // VALUES ($1, $2, $3, $4, $5) + // ON CONFLICT (response_id) DO UPDATE SET ... + + Err(StateStorageError::StorageError( + "Supabase storage not yet implemented".to_string(), + )) + } + + async fn get(&self, response_id: &str) -> Result { + warn!( + "Supabase storage not yet implemented - would retrieve response_id: {}", + response_id + ); + + // TODO: Implement PostgreSQL retrieval + // SQL: SELECT * FROM conversation_states WHERE response_id = $1 + + Err(StateStorageError::StorageError( + "Supabase storage not yet implemented".to_string(), + )) + } + + async fn exists(&self, response_id: &str) -> Result { + debug!("Checking existence for response_id: {}", response_id); + + // TODO: Implement PostgreSQL existence check + // SQL: SELECT EXISTS(SELECT 1 FROM conversation_states WHERE response_id = $1) + + Err(StateStorageError::StorageError( + "Supabase storage not yet implemented".to_string(), + )) + } + + async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> { + debug!("Deleting response_id: {}", response_id); + + // TODO: Implement PostgreSQL deletion + // SQL: DELETE FROM conversation_states WHERE response_id = $1 + + Err(StateStorageError::StorageError( + "Supabase storage not yet implemented".to_string(), + )) + } +} + +/* +Suggested PostgreSQL schema: + +CREATE TABLE conversation_states ( + response_id TEXT PRIMARY KEY, + input_items JSONB NOT NULL, + created_at BIGINT NOT NULL, + model TEXT NOT NULL, + provider TEXT NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX idx_conversation_states_created_at ON conversation_states(created_at); +CREATE INDEX idx_conversation_states_provider ON conversation_states(provider); +*/ + +#[cfg(test)] +mod tests { + use super::*; + use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent}; + + fn create_test_state(response_id: &str) -> OpenAIConversationState { + OpenAIConversationState { + response_id: response_id.to_string(), + input_items: vec![ + InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "Test message".to_string(), + }], + }), + ], + created_at: 1234567890, + model: "gpt-4".to_string(), + provider: "openai".to_string(), + } + } + + // These tests validate the current "not implemented" behavior + // Once the Supabase implementation is complete with actual PostgreSQL integration, + // these should be replaced with comprehensive tests similar to memory.rs + + #[tokio::test] + async fn test_supabase_put_returns_not_implemented() { + let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string()); + let state = create_test_state("resp_001"); + + let result = storage.put(state).await; + assert!(result.is_err()); + + match result.unwrap_err() { + StateStorageError::StorageError(msg) => { + assert!(msg.contains("not yet implemented")); + } + _ => panic!("Expected StorageError"), + } + } + + #[tokio::test] + async fn test_supabase_get_returns_not_implemented() { + let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string()); + + let result = storage.get("resp_002").await; + assert!(result.is_err()); + + match result.unwrap_err() { + StateStorageError::StorageError(msg) => { + assert!(msg.contains("not yet implemented")); + } + _ => panic!("Expected StorageError"), + } + } + + #[tokio::test] + async fn test_supabase_exists_returns_not_implemented() { + let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string()); + + let result = storage.exists("resp_003").await; + assert!(result.is_err()); + + match result.unwrap_err() { + StateStorageError::StorageError(msg) => { + assert!(msg.contains("not yet implemented")); + } + _ => panic!("Expected StorageError"), + } + } + + #[tokio::test] + async fn test_supabase_delete_returns_not_implemented() { + let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string()); + + let result = storage.delete("resp_004").await; + assert!(result.is_err()); + + match result.unwrap_err() { + StateStorageError::StorageError(msg) => { + assert!(msg.contains("not yet implemented")); + } + _ => panic!("Expected StorageError"), + } + } + + #[tokio::test] + async fn test_supabase_merge_works() { + // merge() is implemented in the trait default, so it should work even without DB + let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string()); + + let prev_state = create_test_state("resp_005"); + let current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: vec![InputContent::InputText { + text: "New message".to_string(), + }], + })]; + + let merged = storage.merge(&prev_state, current_input); + + // Should have 2 messages (1 from prev + 1 current) + assert_eq!(merged.len(), 2); + } + + /* TODO: Add comprehensive tests when SupabaseConversationalStorage is implemented + * + * Once the actual PostgreSQL integration is complete, add tests similar to those + * in memory.rs, including: + * + * - test_supabase_put_and_get_success: Store and retrieve state + * - test_supabase_put_overwrites_existing: Verify upsert behavior + * - test_supabase_get_not_found: Check NotFound error handling + * - test_supabase_exists_returns_false: Test non-existent ID + * - test_supabase_exists_returns_true_after_put: Verify existence after insert + * - test_supabase_delete_success: Delete and verify removal + * - test_supabase_delete_not_found: Delete non-existent ID + * - test_supabase_merge_various_scenarios: Test merge with different input combinations + * - test_supabase_concurrent_access: Test with multiple concurrent operations + * - test_supabase_serialization: Verify JSON serialization of input_items + * - test_supabase_connection_failure: Handle connection errors + * - test_supabase_invalid_data: Handle malformed JSON in database + * + * Test setup would require: + * - Test database setup/teardown (perhaps using testcontainers-rs or docker) + * - Connection pool initialization + * - Table creation before tests + * - Data cleanup between tests + */ +} diff --git a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs index 84854af3..ca8a9cfd 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs @@ -59,6 +59,11 @@ pub struct ResponsesAPIStreamBuffer { model: Option, created_at: Option, + /// Full response metadata from upstream (tools, temperature, etc.) + /// This is extracted from the first upstream event and used to build + /// complete response.created and response.in_progress events + upstream_response_metadata: Option, + /// Lifecycle state flags created_emitted: bool, in_progress_emitted: bool, @@ -88,6 +93,7 @@ impl ResponsesAPIStreamBuffer { response_id: None, model: None, created_at: None, + upstream_response_metadata: None, created_emitted: false, in_progress_emitted: false, output_items_added: HashMap::new(), @@ -171,6 +177,15 @@ impl ResponsesAPIStreamBuffer { /// Build the base response object with current state fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse { + // If we have upstream metadata, use it as a base and update status/output + if let Some(upstream) = &self.upstream_response_metadata { + let mut response = upstream.clone(); + response.status = status; + // Don't update output here - will be set in finalize() + return response; + } + + // Fallback: build a minimal response from local state ResponsesAPIResponse { id: self.response_id.clone().unwrap_or_default(), object: "response".to_string(), @@ -293,24 +308,40 @@ impl ResponsesAPIStreamBuffer { // Build final response let mut output_items = Vec::new(); - // Add tool calls to output - for (item_id, arguments) in &self.function_arguments { - let output_index = self.output_items_added.iter() - .find(|(_, id)| *id == item_id) - .map(|(idx, _)| *idx) - .unwrap_or(0); + // Build complete output array by iterating through all output indices in order + let max_output_index = self.output_items_added.keys().max().copied().unwrap_or(-1); - let (call_id, name) = self.tool_call_metadata.get(&output_index) - .cloned() - .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + for output_index in 0..=max_output_index { + if let Some(item_id) = self.output_items_added.get(&output_index) { + // Check if this is a function call + if let Some(arguments) = self.function_arguments.get(item_id) { + let (call_id, name) = self.tool_call_metadata.get(&output_index) + .cloned() + .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); - output_items.push(OutputItem::FunctionCall { - id: item_id.clone(), - status: OutputItemStatus::Completed, - call_id, - name: Some(name), - arguments: Some(arguments.clone()), - }); + output_items.push(OutputItem::FunctionCall { + id: item_id.clone(), + status: OutputItemStatus::Completed, + call_id, + name: Some(name), + arguments: Some(arguments.clone()), + }); + } + // Check if this is a text message + else if let Some(text) = self.text_content.get(item_id) { + use crate::apis::openai_responses::OutputContent; + output_items.push(OutputItem::Message { + id: item_id.clone(), + status: OutputItemStatus::Completed, + role: "assistant".to_string(), + content: vec![OutputContent::OutputText { + text: text.clone(), + annotations: vec![], + logprobs: None, + }], + }); + } + } } let mut final_response = self.build_response(ResponseStatus::Completed); @@ -365,6 +396,24 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer { let mut events = Vec::new(); + // Capture upstream metadata from ResponseCreated or ResponseInProgress if present + match stream_event { + ResponsesAPIStreamEvent::ResponseCreated { response, .. } | + ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => { + if self.upstream_response_metadata.is_none() { + // Store the full upstream response as our metadata template + self.upstream_response_metadata = Some(response.clone()); + // Also extract basic fields + self.response_id = Some(response.id.clone()); + self.model = Some(response.model.clone()); + self.created_at = Some(response.created_at); + } + // Don't emit these - we'll generate our own lifecycle events + return; + } + _ => {} + } + // Emit lifecycle events if not yet emitted if !self.created_emitted { // Initialize metadata from first event if needed diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index 09ab262d..5a923329 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -193,6 +193,40 @@ impl SupportedAPIsFromClient { } } + +impl SupportedUpstreamAPIs { + /// Create a SupportedUpstreamApi from an endpoint path + pub fn from_endpoint(endpoint: &str) -> Option { + if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) { + // Check if this is the Responses API endpoint + if openai_api == OpenAIApi::Responses { + return Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(openai_api)); + } + // Otherwise it's ChatCompletions + return Some(SupportedUpstreamAPIs::OpenAIChatCompletions(openai_api)); + } + + if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) { + return Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(anthropic_api)); + } + + if let Some(bedrock_api) = AmazonBedrockApi::from_endpoint(endpoint) { + match bedrock_api { + AmazonBedrockApi::Converse => { + return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api)) + } + AmazonBedrockApi::ConverseStream => { + return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(bedrock_api)) + } + } + } + + None + } + +} + + /// Get all supported endpoint paths pub fn supported_endpoints() -> Vec<&'static str> { let mut endpoints = Vec::new(); diff --git a/crates/hermesllm/src/transforms/response/mod.rs b/crates/hermesllm/src/transforms/response/mod.rs index 3ce75123..1dd0d4ea 100644 --- a/crates/hermesllm/src/transforms/response/mod.rs +++ b/crates/hermesllm/src/transforms/response/mod.rs @@ -1,3 +1,4 @@ //! Response transformation modules +pub mod output_to_input; pub mod to_anthropic; pub mod to_openai; diff --git a/crates/hermesllm/src/transforms/response/output_to_input.rs b/crates/hermesllm/src/transforms/response/output_to_input.rs new file mode 100644 index 00000000..4f4bd34c --- /dev/null +++ b/crates/hermesllm/src/transforms/response/output_to_input.rs @@ -0,0 +1,166 @@ +//! Conversions from response outputs to request inputs for conversation continuation +//! +//! This module provides utilities for converting OutputItem types from API responses +//! into InputItem types that can be used in subsequent requests. This is primarily used +//! for maintaining conversation history in the v1/responses API. + +use crate::apis::openai_responses::{ + InputContent, InputItem, InputMessage, MessageRole, OutputContent, OutputItem, +}; + +/// Converts an OutputItem from a response into an InputItem for the next request +/// This is used to build conversation history from previous responses +pub fn output_item_to_input_item(output: &OutputItem) -> Option { + match output { + // Convert output messages to input messages + OutputItem::Message { + role, content, .. + } => { + let input_content: Vec = content + .iter() + .filter_map(|c| match c { + OutputContent::OutputText { text, .. } => Some(InputContent::InputText { + text: text.clone(), + }), + OutputContent::OutputAudio { + data, .. + } => Some(InputContent::InputAudio { + data: data.clone(), + format: None, // Format not preserved in output + }), + OutputContent::Refusal { .. } => None, // Skip refusals + }) + .collect(); + + if input_content.is_empty() { + return None; + } + + // Map role string to MessageRole enum + let message_role = match role.as_str() { + "user" => MessageRole::User, + "assistant" => MessageRole::Assistant, + "system" => MessageRole::System, + "developer" => MessageRole::Developer, + _ => MessageRole::Assistant, // Default to assistant + }; + + Some(InputItem::Message(InputMessage { + role: message_role, + content: input_content, + })) + } + // For function calls, we'll create an assistant message with the tool call info + // This matches how conversation history is typically built + OutputItem::FunctionCall { + name, arguments, .. + } => { + let tool_call_text = if let (Some(n), Some(args)) = (name, arguments) { + format!("Called function: {} with arguments: {}", n, args) + } else { + "Called a function".to_string() + }; + + Some(InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: vec![InputContent::InputText { + text: tool_call_text, + }], + })) + } + // Skip other output types (tool outputs, etc.) as they don't convert to input + _ => None, + } +} + +/// Converts a Vec of OutputItems into InputItems for conversation continuation +pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec { + outputs + .iter() + .filter_map(output_item_to_input_item) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::apis::openai_responses::{OutputItemStatus}; + + #[test] + fn test_output_message_to_input() { + let output = OutputItem::Message { + id: "msg_123".to_string(), + status: OutputItemStatus::Completed, + role: "assistant".to_string(), + content: vec![OutputContent::OutputText { + text: "Hello!".to_string(), + annotations: vec![], + logprobs: None, + }], + }; + + let input = output_item_to_input_item(&output).unwrap(); + + match input { + InputItem::Message(msg) => { + assert!(matches!(msg.role, MessageRole::Assistant)); + assert_eq!(msg.content.len(), 1); + match &msg.content[0] { + InputContent::InputText { text } => assert_eq!(text, "Hello!"), + _ => panic!("Expected InputText"), + } + } + } + } + + #[test] + fn test_function_call_to_input() { + let output = OutputItem::FunctionCall { + id: "fc_123".to_string(), + status: OutputItemStatus::Completed, + call_id: "call_123".to_string(), + name: Some("get_weather".to_string()), + arguments: Some(r#"{"location":"SF"}"#.to_string()), + }; + + let input = output_item_to_input_item(&output).unwrap(); + + match input { + InputItem::Message(msg) => { + assert!(matches!(msg.role, MessageRole::Assistant)); + match &msg.content[0] { + InputContent::InputText { text } => { + assert!(text.contains("get_weather")); + } + _ => panic!("Expected InputText"), + } + } + } + } + + #[test] + fn test_outputs_to_inputs() { + let outputs = vec![ + OutputItem::Message { + id: "msg_1".to_string(), + status: OutputItemStatus::Completed, + role: "assistant".to_string(), + content: vec![OutputContent::OutputText { + text: "Hello".to_string(), + annotations: vec![], + logprobs: None, + }], + }, + OutputItem::FunctionCall { + id: "fc_1".to_string(), + status: OutputItemStatus::Completed, + call_id: "call_1".to_string(), + name: Some("test".to_string()), + arguments: Some("{}".to_string()), + }, + ]; + + let inputs = outputs_to_inputs(&outputs); + assert_eq!(inputs.len(), 2); + } +} diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs index e26cc3b4..6ece5992 100644 --- a/crates/hermesllm/src/transforms/response/to_openai.rs +++ b/crates/hermesllm/src/transforms/response/to_openai.rs @@ -80,8 +80,15 @@ impl TryFrom for ResponsesAPIResponse { // Only add the message item if there's actual content (text, audio, or refusal) // Don't add empty message items when there are only tool calls if !content.is_empty() { + // Avoid double-prefixing: if ID already starts with "msg_", use as-is + let message_id = if resp.id.starts_with("msg_") { + resp.id.clone() + } else { + format!("msg_{}", resp.id) + }; + items.push(OutputItem::Message { - id: format!("msg_{}", resp.id), + id: message_id, status: OutputItemStatus::Completed, role: match choice.message.role { Role::User => "user".to_string(), @@ -151,7 +158,12 @@ impl TryFrom for ResponsesAPIResponse { }; Ok(ResponsesAPIResponse { - id: resp.id, + // Generate proper resp_ prefixed ID if not already present + id: if resp.id.starts_with("resp_") { + resp.id + } else { + format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")) + }, object: "response".to_string(), created_at: resp.created as i64, status, @@ -942,7 +954,7 @@ mod tests { use crate::apis::openai_responses::{OutputContent, OutputItem, ResponsesAPIResponse}; let chat_response = ChatCompletionsResponse { - id: "chatcmpl-123".to_string(), + id: "resp_6de5512800cf4375a329a473a4f02879".to_string(), object: Some("chat.completion".to_string()), created: 1677652288, model: "gpt-4".to_string(), @@ -974,7 +986,9 @@ mod tests { let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap(); - assert_eq!(responses_api.id, "chatcmpl-123"); + // Response ID should be generated with resp_ prefix + assert!(responses_api.id.starts_with("resp_"), "Response ID should start with 'resp_'"); + assert_eq!(responses_api.id.len(), 37, "Response ID should be resp_ + 32 char UUID"); assert_eq!(responses_api.object, "response"); assert_eq!(responses_api.model, "gpt-4"); diff --git a/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs index 9e2f083e..30b40956 100644 --- a/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs +++ b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs @@ -58,11 +58,11 @@ impl TryFrom for ChatCompletionsStreamResponse { None, )), - MessagesStreamEvent::ContentBlockStart { content_block, .. } => { - convert_content_block_start(content_block) + MessagesStreamEvent::ContentBlockStart { content_block, index } => { + convert_content_block_start(content_block, index) } - MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta), + MessagesStreamEvent::ContentBlockDelta { delta, index } => convert_content_delta(delta, index), MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()), @@ -272,6 +272,7 @@ impl TryFrom for ChatCompletionsStreamResponse { /// Convert content block start to OpenAI chunk fn convert_content_block_start( content_block: MessagesContentBlock, + index: u32, ) -> Result { match content_block { MessagesContentBlock::Text { .. } => { @@ -291,7 +292,7 @@ fn convert_content_block_start( refusal: None, function_call: None, tool_calls: Some(vec![ToolCallDelta { - index: 0, + index, id: Some(id), call_type: Some("function".to_string()), function: Some(FunctionCallDelta { @@ -313,6 +314,7 @@ fn convert_content_block_start( /// Convert content delta to OpenAI chunk fn convert_content_delta( delta: MessagesContentDelta, + index: u32, ) -> Result { match delta { MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk( @@ -350,7 +352,7 @@ fn convert_content_delta( refusal: None, function_call: None, tool_calls: Some(vec![ToolCallDelta { - index: 0, + index, id: None, call_type: None, function: Some(FunctionCallDelta { diff --git a/tests/e2e/test_openai_responses_api_client.py b/tests/e2e/test_openai_responses_api_client.py index 800db93d..7ccf1bb8 100644 --- a/tests/e2e/test_openai_responses_api_client.py +++ b/tests/e2e/test_openai_responses_api_client.py @@ -628,3 +628,204 @@ def test_openai_responses_api_streaming_with_tools_upstream_anthropic(): assert ( full_text or tool_calls ), "Expected streamed text or tool call argument deltas from Responses tools stream" + + +def test_conversation_state_management_two_turn(): + """ + Test conversation state management across two turns: + 1. Send initial message to non-OpenAI model via v1/responses + 2. Capture response_id from first response + 3. Send second message with previous_response_id + 4. Verify model receives both messages in correct order + """ + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + logger.info("\n" + "=" * 80) + logger.info("TEST: Conversation State Management - Two Turn Flow") + logger.info("=" * 80) + + # Turn 1: Send initial message to Anthropic (non-OpenAI model) + logger.info("\n[TURN 1] Sending initial message...") + resp1 = client.responses.create( + model="claude-sonnet-4-20250514", + input="My name is Alice and I like pizza.", + ) + + # Extract response_id from first response + response_id_1 = resp1.id + logger.info(f"[TURN 1] Received response_id: {response_id_1}") + logger.info(f"[TURN 1] Model response: {resp1.output_text}") + + assert response_id_1 is not None, "First response should have an id" + assert len(resp1.output_text) > 0, "First response should have content" + + # Turn 2: Send follow-up message with previous_response_id + # Ask the model to list all messages to verify state was combined + logger.info( + f"\n[TURN 2] Sending follow-up with previous_response_id={response_id_1}" + ) + resp2 = client.responses.create( + model="claude-sonnet-4-20250514", + input="Please list all the messages you have received in our conversation, numbering each one.", + previous_response_id=response_id_1, + ) + + response_id_2 = resp2.id + logger.info(f"[TURN 2] Received response_id: {response_id_2}") + logger.info(f"[TURN 2] Model response: {resp2.output_text}") + + assert response_id_2 is not None, "Second response should have an id" + assert response_id_2 != response_id_1, "Second response should have different id" + + # Verify the model received the conversation history + # The response should reference both the initial message and the follow-up + response_lower = resp2.output_text.lower() + + # Check if the model acknowledges receiving multiple messages + # Different models might format this differently, so we check for various indicators + has_conversation_context = ( + "alice" in response_lower + or "pizza" in response_lower # References the name from turn 1 + or "two" in response_lower # References the preference from turn 1 + or "2" in response_lower # Mentions number of messages + or "first" in response_lower # Numeric indicator + or "second" # References first message + in response_lower # References second message + ) + + logger.info( + f"\n[VALIDATION] Conversation context preserved: {has_conversation_context}" + ) + logger.info( + f"[VALIDATION] Response contains conversation markers: {has_conversation_context}" + ) + + print(f"\n{'='*80}") + print("Conversation State Test Results:") + print(f"Turn 1 Response ID: {response_id_1}") + print(f"Turn 2 Response ID: {response_id_2}") + print(f"Turn 1 Output: {resp1.output_text[:100]}...") + print(f"Turn 2 Output: {resp2.output_text}") + print(f"Conversation Context Preserved: {has_conversation_context}") + print(f"{'='*80}\n") + + assert has_conversation_context, ( + f"Model should have received conversation history. " + f"Response: {resp2.output_text}" + ) + + +def test_conversation_state_management_two_turn_streaming(): + """ + Test conversation state management across two turns with streaming: + 1. Send initial streaming message to non-OpenAI model via v1/responses + 2. Capture response_id from first response + 3. Send second streaming message with previous_response_id + 4. Verify model receives both messages in correct order + """ + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + logger.info("\n" + "=" * 80) + logger.info("TEST: Conversation State Management - Two Turn Streaming Flow") + logger.info("=" * 80) + + # Turn 1: Send initial streaming message to Anthropic (non-OpenAI model) + logger.info("\n[TURN 1] Sending initial streaming message...") + stream1 = client.responses.create( + model="claude-sonnet-4-20250514", + input="My name is Alice and I like pizza.", + stream=True, + ) + + # Collect streamed content and capture response_id + text_chunks_1 = [] + response_id_1 = None + + for event in stream1: + if getattr(event, "type", None) == "response.output_text.delta" and getattr( + event, "delta", None + ): + text_chunks_1.append(event.delta) + + # Capture response_id from response.completed event + if getattr(event, "type", None) == "response.completed" and getattr( + event, "response", None + ): + response_id_1 = event.response.id + + output_1 = "".join(text_chunks_1) + logger.info(f"[TURN 1] Received response_id: {response_id_1}") + logger.info(f"[TURN 1] Model response: {output_1}") + + assert response_id_1 is not None, "First response should have an id" + assert len(output_1) > 0, "First response should have content" + + # Turn 2: Send follow-up streaming message with previous_response_id + logger.info( + f"\n[TURN 2] Sending follow-up streaming request with previous_response_id={response_id_1}" + ) + stream2 = client.responses.create( + model="claude-sonnet-4-20250514", + input="Please list all the messages you have received in our conversation, numbering each one.", + previous_response_id=response_id_1, + stream=True, + ) + + # Collect streamed content from second response + text_chunks_2 = [] + response_id_2 = None + + for event in stream2: + if getattr(event, "type", None) == "response.output_text.delta" and getattr( + event, "delta", None + ): + text_chunks_2.append(event.delta) + + # Capture response_id from response.completed event + if getattr(event, "type", None) == "response.completed" and getattr( + event, "response", None + ): + response_id_2 = event.response.id + + output_2 = "".join(text_chunks_2) + logger.info(f"[TURN 2] Received response_id: {response_id_2}") + logger.info(f"[TURN 2] Model response: {output_2}") + + assert response_id_2 is not None, "Second response should have an id" + assert response_id_2 != response_id_1, "Second response should have different id" + + # Verify the model received the conversation history + response_lower = output_2.lower() + + # Check if the model acknowledges receiving multiple messages + has_conversation_context = ( + "alice" in response_lower + or "pizza" in response_lower # References the name from turn 1 + or "two" in response_lower # References the preference from turn 1 + or "2" in response_lower # Mentions number of messages + or "first" in response_lower # Numeric indicator + or "second" # References first message + in response_lower # References second message + ) + + logger.info( + f"\n[VALIDATION] Conversation context preserved: {has_conversation_context}" + ) + logger.info( + f"[VALIDATION] Response contains conversation markers: {has_conversation_context}" + ) + + print(f"\n{'='*80}") + print("Streaming Conversation State Test Results:") + print(f"Turn 1 Response ID: {response_id_1}") + print(f"Turn 2 Response ID: {response_id_2}") + print(f"Turn 1 Output: {output_1[:100]}...") + print(f"Turn 2 Output: {output_2}") + print(f"Conversation Context Preserved: {has_conversation_context}") + print(f"{'='*80}\n") + + assert has_conversation_context, ( + f"Model should have received conversation history. " f"Response: {output_2}" + )