diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml
index f481b389..1474302c 100644
--- a/arch/arch_config_schema.yaml
+++ b/arch/arch_config_schema.yaml
@@ -331,6 +331,31 @@ properties:
         model:
           type: string
       additionalProperties: false
+  state_storage:
+    type: object
+    properties:
+      type:
+        type: string
+        enum:
+          - memory
+          - postgres
+      connection_string:
+        type: string
+        description: Required when type is postgres. Supports environment variable substitution using $VAR or ${VAR} syntax.
+    additionalProperties: false
+    required:
+      - type
+    # Note: connection_string is conditionally required based on type.
+    # If type is 'postgres', connection_string must be provided.
+    # If type is 'memory', connection_string is not needed.
+    allOf:
+      - if:
+          properties:
+            type:
+              const: postgres
+        then:
+          required:
+            - connection_string
  prompt_guards:
    type: object
    properties:
diff --git a/arch/supervisord.conf b/arch/supervisord.conf
index 2b715a1e..9761f779 100644
--- a/arch/supervisord.conf
+++ b/arch/supervisord.conf
@@ -2,7 +2,7 @@
 nodaemon=true
 
 [program:brightstaff]
-command=sh -c "RUST_LOG=info /app/brightstaff 2>&1 | tee /var/log/brightstaff.log | while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
+command=sh -c "envsubst < /app/arch_config_rendered.yaml > /app/arch_config_rendered.env_sub.yaml && RUST_LOG=debug ARCH_CONFIG_PATH_RENDERED=/app/arch_config_rendered.env_sub.yaml /app/brightstaff 2>&1 | tee /var/log/brightstaff.log | while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
 stdout_logfile=/dev/stdout
 redirect_stderr=true
 stdout_logfile_maxbytes=0
diff --git a/arch/tools/cli/utils.py b/arch/tools/cli/utils.py
index 2f29b16e..6db34585 100644
--- a/arch/tools/cli/utils.py
+++ b/arch/tools/cli/utils.py
@@ -148,6 +148,24 @@ def get_llm_provider_access_keys(arch_config_file):
         if access_key is not None:
             access_key_list.append(access_key)
 
+    # Extract environment variables from state_storage.connection_string
+    state_storage = arch_config_yaml.get("state_storage")
+    if state_storage:
+        connection_string = state_storage.get("connection_string")
+        if connection_string and isinstance(connection_string, str):
+            import re
+
+            # Match both $VAR and ${VAR} patterns
+            pattern = r"\$\{?([A-Z_][A-Z0-9_]*)\}?"
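+            # e.g. "postgresql://user:$PGPASSWORD@${DB_HOST}:5432/archgw"
+            #      yields matches ["PGPASSWORD", "DB_HOST"]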
+            matches = re.findall(pattern, connection_string)
+            for var in matches:
+                access_key_list.append(f"${var}")
+        elif connection_string is not None:
+            # Present but not a usable string is a configuration error; a
+            # missing connection_string is valid (e.g., type: memory).
+            raise ValueError(
+                "Invalid connection string received in state_storage"
+            )
+
     return access_key_list
diff --git a/crates/Cargo.lock b/crates/Cargo.lock
index 09c86861..01b15dc5 100644
--- a/crates/Cargo.lock
+++ b/crates/Cargo.lock
@@ -308,11 +308,13 @@ name = "brightstaff"
 version = "0.1.0"
 dependencies = [
  "async-openai",
+ "async-trait",
  "bytes",
  "chrono",
  "common",
  "eventsource-client",
  "eventsource-stream",
+ "flate2",
  "futures",
  "futures-util",
  "hermesllm",
@@ -336,6 +338,7 @@ dependencies = [
  "thiserror 2.0.12",
  "time",
  "tokio",
+ "tokio-postgres",
  "tokio-stream",
  "tracing",
  "tracing-opentelemetry",
@@ -360,6 +363,12 @@ version = "3.18.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee"
 
+[[package]]
+name = "byteorder"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
+
 [[package]]
 name = "bytes"
 version = "1.10.1"
@@ -604,6 +613,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
 dependencies = [
  "block-buffer",
  "crypto-common",
+ "subtle",
 ]
 
 [[package]]
@@ -691,6 +701,12 @@ dependencies = [
  "pin-project-lite",
 ]
 
+[[package]]
+name = "fallible-iterator"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
+
 [[package]]
 name = "fancy-regex"
 version = "0.12.0"
@@ -707,6 +723,16 @@ version = "2.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
 
+[[package]]
+name = "flate2"
+version = "1.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
+dependencies = [
+ "crc32fast",
+ "miniz_oxide",
+]
+
 [[package]]
 name = "fnv"
 version = "1.0.7"
@@ -986,6 +1012,15 @@ version = "0.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
 
+[[package]]
+name = "hmac"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
+dependencies = [
+ "digest",
+]
+
 [[package]]
 name = "http"
 version = "0.2.12"
@@ -1420,6 +1455,17 @@ version = "0.2.172"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
 
+[[package]]
+name = "libredox"
+version = "0.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb"
+dependencies = [
+ "bitflags",
+ "libc",
+ "redox_syscall",
+]
+
 [[package]]
 name = "linux-raw-sys"
 version = "0.9.4"
@@ -1492,6 +1538,16 @@ version = "0.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
 
+[[package]]
+name = "md-5"
+version = "0.10.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf"
+dependencies = [
+ "cfg-if",
+ "digest",
+]
+
 [[package]]
 name = "md5"
version = "0.7.0" @@ -1533,6 +1589,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -1836,6 +1893,24 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.10" @@ -1880,6 +1955,37 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +[[package]] +name = "postgres-protocol" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbef655056b916eb868048276cfd5d6a7dea4f81560dfd047f97c8c6fe3fcfd4" +dependencies = [ + "base64 0.22.1", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand 0.9.2", + "sha2", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol", + "serde", + "serde_json", +] + [[package]] name = "potential_utf" version = "0.1.2" @@ -2109,9 +2215,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.12" +version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ "bitflags", ] @@ -2650,12 +2756,24 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + [[package]] name = "similar" version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "slab" version = "0.4.9" @@ -2696,6 +2814,17 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "strsim" version = "0.11.1" @@ -2954,6 +3083,32 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-postgres" +version = "0.7.13" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand 0.9.2", + "socket2", + "tokio", + "tokio-util", + "whoami", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -3189,12 +3344,33 @@ version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + [[package]] name = "unicode-ident" version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unicode-normalization" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-properties" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" + [[package]] name = "unsafe-libyaml" version = "0.2.11" @@ -3290,6 +3466,12 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -3394,6 +3576,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "whoami" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" +dependencies = [ + "libredox", + "wasite", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index 2d88e213..233a4da3 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -5,11 +5,13 @@ edition = "2021" [dependencies] async-openai = "0.30.1" +async-trait = "0.1" bytes = "1.10.1" chrono = "0.4" common = { version = "0.1.0", path = "../common", features = ["trace-collection"] } eventsource-client = "0.15.0" eventsource-stream = "0.2.3" +flate2 = "1.0" futures = "0.3.31" futures-util = "0.3.31" hermesllm = { version = "0.1.0", path = "../hermesllm" } @@ -31,6 +33,7 @@ serde_with = "3.13.0" serde_yaml = "0.9.34" thiserror = "2.0.12" tokio = { version = "1.44.2", features = ["full"] } +tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] } tokio-stream = "0.1" time = { version = "0.3", features = ["formatting", "macros"] } tracing = "0.1" diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index b3686fae..5c5bcf01 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -1,8 +1,9 @@ use bytes::Bytes; use common::configuration::{LlmProvider, ModelAlias}; -use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER}; +use 
common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER}; use common::traces::TraceCollector; -use hermesllm::clients::SupportedAPIsFromClient; +use hermesllm::apis::openai_responses::InputParam; +use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use hermesllm::{ProviderRequest, ProviderRequestType}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full}; @@ -11,11 +12,16 @@ use hyper::{Request, Response, StatusCode}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use tracing::{debug, warn}; +use tracing::{debug, info, warn}; use crate::router::llm_router::RouterService; use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message}; use crate::handlers::router_chat::router_chat_get_upstream_model; +use crate::state::response_state_processor::ResponsesStateProcessor; +use crate::state::{ + StateStorage, StateStorageError, + extract_input_items, retrieve_and_combine_input +}; use crate::tracing::operation_component; fn full>(chunk: T) -> BoxBody { @@ -31,14 +37,20 @@ pub async fn llm_chat( model_aliases: Arc>>, llm_providers: Arc>>, trace_collector: Arc, + state_storage: Option>, ) -> Result>, hyper::Error> { let request_path = request.uri().path().to_string(); let request_headers = request.headers().clone(); + let request_id = request_headers + .get(REQUEST_ID_HEADER) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_else(|| "unknown".to_string()); // Extract or generate traceparent - this establishes the trace context for all spans let traceparent: String = request_headers - .get("traceparent") + .get(TRACE_PARENT_HEADER) .and_then(|h| h.to_str().ok()) .map(|s| s.to_string()) .unwrap_or_else(|| { @@ -51,7 +63,8 @@ pub async fn llm_chat( let chat_request_bytes = request.collect().await?.to_bytes(); debug!( - "Received request body (raw utf8): {}", + "[PLANO_REQ_ID:{}] | REQUEST_BODY (UTF8): {}", + request_id, String::from_utf8_lossy(&chat_request_bytes) ); @@ -61,14 +74,19 @@ pub async fn llm_chat( )) { Ok(request) => request, Err(err) => { - warn!("Failed to parse request as ProviderRequestType: {}", err); - let err_msg = format!("Failed to parse request: {}", err); + warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err); + let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err); let mut bad_request = Response::new(full(err_msg)); *bad_request.status_mut() = StatusCode::BAD_REQUEST; return Ok(bad_request); } }; + // === v1/responses state management: Extract input items early === + let mut original_input_items = Vec::new(); + let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str()); + let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))); + // Model alias resolution: update model field in client_request immediately // This ensures all downstream objects use the resolved model let model_from_request = client_request.model().to_string(); @@ -83,9 +101,77 @@ pub async fn llm_chat( client_request.set_model(resolved_model.clone()); if client_request.remove_metadata_key("archgw_preference_config") { - debug!("Removed archgw_preference_config from metadata"); + debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id); } + // === v1/responses state management: Determine upstream API and combine input if needed === 
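+    // Illustrative flow: a first request returns some response_id; the follow-up
+    // request sends that id as previous_response_id along with only its new input
+    // items, and the stored history is prepended below before upstream translation.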
+    // Do this BEFORE routing since routing consumes the request
+    // Only process state if state_storage is configured
+    let mut should_manage_state = false;
+    if is_responses_api_client && state_storage.is_some() {
+        if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
+            // Extract original input once
+            original_input_items = extract_input_items(&responses_req.input);
+
+            // Get the upstream path and check if it's ResponsesAPI
+            let upstream_path = get_upstream_path(
+                &llm_providers,
+                &resolved_model,
+                &request_path,
+                &resolved_model,
+                is_streaming_request,
+            ).await;
+
+            let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
+
+            // Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
+            should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
+
+            if should_manage_state {
+                // Retrieve and combine conversation history if previous_response_id exists
+                if let Some(ref prev_resp_id) = responses_req.previous_response_id {
+                    match retrieve_and_combine_input(
+                        state_storage.as_ref().unwrap().clone(),
+                        prev_resp_id,
+                        original_input_items, // Pass ownership instead of cloning
+                    )
+                    .await
+                    {
+                        Ok(combined_input) => {
+                            // Update both the request and original_input_items
+                            responses_req.input = InputParam::Items(combined_input.clone());
+                            original_input_items = combined_input;
+                            info!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Updated request with conversation history ({} items)", request_id, original_input_items.len());
+                        }
+                        Err(StateStorageError::NotFound(_)) => {
+                            // Return 409 Conflict when previous_response_id not found
+                            warn!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Previous response_id not found: {}", request_id, prev_resp_id);
+                            let err_msg = format!(
+                                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Conversation state not found for previous_response_id: {}",
+                                request_id, prev_resp_id
+                            );
+                            let mut conflict_response = Response::new(full(err_msg));
+                            *conflict_response.status_mut() = StatusCode::CONFLICT;
+                            return Ok(conflict_response);
+                        }
+                        Err(e) => {
+                            // Log warning but continue on other storage errors
+                            warn!(
+                                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to retrieve conversation state for {}: {}",
+                                request_id, prev_resp_id, e
+                            );
+                            // Restore original_input_items since we passed ownership
+                            original_input_items = extract_input_items(&responses_req.input);
+                        }
+                    }
+                }
+            } else {
+                debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
+            }
+        }
+    }
+
+    // Serialize request for upstream BEFORE router consumes it
     let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
 
     // Determine routing using the dedicated router_chat module
@@ -110,8 +196,8 @@
     let model_name = routing_result.model_name;
 
     debug!(
-        "[ARCH_ROUTER] URL: {}, Resolved Model: {}",
-        full_qualified_llm_provider_url, model_name
+        "[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Resolved Model: {}",
+        request_id, full_qualified_llm_provider_url, model_name
     );
 
     request_headers.insert(
@@ -173,15 +259,40 @@
         &llm_providers,
     ).await;
 
-    // Use PassthroughProcessor to track streaming metrics and finalize the span
-    let processor = ObservableStreamProcessor::new(
+    // Create base processor for metrics and tracing
+    let base_processor = ObservableStreamProcessor::new(
         trace_collector,
         operation_component::LLM,
         llm_span,
         request_start_time,
     );
 
-    let streaming_response = create_streaming_response(byte_stream, processor, 16);
+    // === v1/responses state management: Wrap with ResponsesStateProcessor ===
+    // Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
+    let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() {
+        // Extract Content-Encoding header to handle decompression for state parsing
+        let content_encoding = response_headers
+            .get("content-encoding")
+            .and_then(|v| v.to_str().ok())
+            .map(|s| s.to_string());
+
+        // Wrap with state management processor to store state after response completes
+        let state_processor = ResponsesStateProcessor::new(
+            base_processor,
+            state_storage.unwrap(),
+            original_input_items,
+            resolved_model.clone(),
+            model_name.clone(),
+            is_streaming_request,
+            false, // Not OpenAI upstream since should_manage_state is true
+            content_encoding,
+            request_id.clone(),
+        );
+        create_streaming_response(byte_stream, state_processor, 16)
+    } else {
+        // Use base processor without state management
+        create_streaming_response(byte_stream, base_processor, 16)
+    };
 
     match response.body(streaming_response.body) {
         Ok(response) => Ok(response),
@@ -301,35 +412,7 @@
     resolved_model: &str,
     is_streaming: bool,
 ) -> String {
-    let providers_lock = llm_providers.read().await;
-
-    // First, try to find by model name or provider name
-    let provider = providers_lock.iter().find(|p| {
-        p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
-            || p.name == model_name
-    });
-
-    let (provider_id, base_url_path_prefix) = if let Some(provider) = provider {
-        let provider_id = provider.provider_interface.to_provider_id();
-        let prefix = provider.base_url_path_prefix.clone();
-        (provider_id, prefix)
-    } else {
-        let default_provider = providers_lock.iter().find(|p| {
-            p.default.unwrap_or(false)
-        });
-
-        if let Some(provider) = default_provider {
-            let provider_id = provider.provider_interface.to_provider_id();
-            let prefix = provider.base_url_path_prefix.clone();
-            (provider_id, prefix)
-        } else {
-            // Last resort: use OpenAI as hardcoded fallback
-            warn!("No default provider found, falling back to OpenAI");
-            (hermesllm::ProviderId::OpenAI, None)
-        }
-    };
-
-    drop(providers_lock);
+    let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
 
     // Calculate the upstream path using the proper API
     let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
@@ -343,3 +426,37 @@
         base_url_path_prefix.as_deref(),
     )
 }
+
+/// Helper function to get provider info (ProviderId and base_url_path_prefix)
+async fn get_provider_info(
+    llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
+    model_name: &str,
+) -> (hermesllm::ProviderId, Option<String>) {
+    let providers_lock = llm_providers.read().await;
+
+    // First, try to find by model name or provider name
+    let provider = providers_lock.iter().find(|p| {
+        p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
+            || p.name == model_name
+    });
+
+    if let Some(provider) = provider {
+        let provider_id = provider.provider_interface.to_provider_id();
+        let prefix = provider.base_url_path_prefix.clone();
+        return (provider_id, prefix);
+    }
+
+    let default_provider = providers_lock.iter().find(|p| {
+        p.default.unwrap_or(false)
+    });
+
+    if let Some(provider) = default_provider {
+        let provider_id = provider.provider_interface.to_provider_id();
+        let prefix = provider.base_url_path_prefix.clone();
+        (provider_id, prefix)
+    } else {
+        // Last resort: use OpenAI as hardcoded fallback
+        warn!("No default provider found, falling back to OpenAI");
+        (hermesllm::ProviderId::OpenAI, None)
+    }
+}
diff --git a/crates/brightstaff/src/handlers/router_chat.rs b/crates/brightstaff/src/handlers/router_chat.rs
index 09b09975..a927a0eb 100644
--- a/crates/brightstaff/src/handlers/router_chat.rs
+++ b/crates/brightstaff/src/handlers/router_chat.rs
@@ -1,4 +1,5 @@
 use common::configuration::ModelUsagePreference;
+use common::consts::REQUEST_ID_HEADER;
 use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
 use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
 use hermesllm::{ProviderRequest, ProviderRequestType};
@@ -43,6 +44,10 @@
 ) -> Result {
     // Clone metadata for routing before converting (which consumes client_request)
     let routing_metadata = client_request.metadata().clone();
+    let request_id = request_headers
+        .get(REQUEST_ID_HEADER)
+        .and_then(|value| value.to_str().ok())
+        .unwrap_or("unknown");
 
     // Convert to ChatCompletionsRequest for routing (regardless of input type)
     let chat_request = match ProviderRequestType::try_from((
@@ -73,7 +78,8 @@
     };
 
     debug!(
-        "[ARCH_ROUTER REQ]: {}",
+        "[PLANO_REQ_ID:{}] | ROUTER_REQ | {}",
+        request_id,
         &serde_json::to_string(&chat_request).unwrap()
     );
@@ -114,14 +120,13 @@
     };
 
     info!(
-        "request received, request type: chat_completion, usage preferences from request: {}, request path: {}, latest message: {}",
+        "[PLANO_REQ_ID:{}] | ROUTER_REQ | Usage preferences from request: {}, request_path: {}, latest message: {}",
+        request_id,
         usage_preferences.is_some(),
         request_path,
         latest_message_for_log
     );
 
-    debug!("usage preferences from request: {:?}", usage_preferences);
-
     // Capture start time for routing span
     let routing_start_time = std::time::Instant::now();
     let routing_start_system_time = std::time::SystemTime::now();
@@ -153,7 +158,8 @@
         None => {
             // No route determined, use default model from request
             info!(
-                "No route determined, using default model from request: {}",
+                "[PLANO_REQ_ID:{}] | ROUTER_REQ | No route determined, using default model from request: {}",
+                request_id,
                 chat_request.model
             );
diff --git a/crates/brightstaff/src/lib.rs b/crates/brightstaff/src/lib.rs
index ceff49f1..36fc902f 100644
--- a/crates/brightstaff/src/lib.rs
+++ b/crates/brightstaff/src/lib.rs
@@ -1,4 +1,5 @@
 pub mod handlers;
 pub mod router;
+pub mod state;
 pub mod tracing;
 pub mod utils;
diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs
index d0241fa3..325280e8 100644
--- a/crates/brightstaff/src/main.rs
+++ b/crates/brightstaff/src/main.rs
@@ -3,6 +3,9 @@ use brightstaff::handlers::llm::llm_chat;
 use brightstaff::handlers::models::list_models;
 use brightstaff::handlers::function_calling::{function_calling_chat_handler};
 use brightstaff::router::llm_router::RouterService;
+use brightstaff::state::StateStorage;
+use brightstaff::state::postgresql::PostgreSQLConversationStorage;
+use brightstaff::state::memory::MemoryConversationalStorage;
 use brightstaff::utils::tracing::init_tracer;
 use bytes::Bytes;
 use common::configuration::Configuration;
@@ -101,6 +104,37 @@
     let trace_collector = Arc::new(TraceCollector::new(tracing_enabled));
     let _flusher_handle = trace_collector.clone().start_background_flusher();
 
+    // Initialize conversation state storage for v1/responses.
+    // Configurable via the arch_config.yaml state_storage section;
+    // if not configured, state management is disabled.
+    // Environment variables are substituted by envsubst before the config is read.
+    let state_storage: Option<Arc<dyn StateStorage>> = if let Some(storage_config) = &arch_config.state_storage {
+        let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
+            common::configuration::StateStorageType::Memory => {
+                info!("Initialized conversation state storage: Memory");
+                Arc::new(MemoryConversationalStorage::new())
+            }
+            common::configuration::StateStorageType::Postgres => {
+                let connection_string = storage_config
+                    .connection_string
+                    .as_ref()
+                    .expect("connection_string is required for postgres state_storage");
+
+                // Note: the connection string is never logged, since it may contain credentials
+                info!("Initializing conversation state storage: Postgres");
+                Arc::new(
+                    PostgreSQLConversationStorage::new(connection_string.clone())
+                        .await
+                        .expect("Failed to initialize Postgres state storage"),
+                )
+            }
+        };
+        Some(storage)
+    } else {
+        info!("No state_storage configured - conversation state management disabled");
+        None
+    };
+
     loop {
         let (stream, _) = listener.accept().await?;
@@ -115,6 +149,7 @@
             let agents_list = agents_list.clone();
             let listeners = listeners.clone();
             let trace_collector = trace_collector.clone();
+            let state_storage = state_storage.clone();
             let service = service_fn(move |req| {
                 let router_service = Arc::clone(&router_service);
                 let parent_cx = extract_context_from_request(&req);
@@ -124,13 +159,14 @@
                 let agents_list = agents_list.clone();
                 let listeners = listeners.clone();
                 let trace_collector = trace_collector.clone();
+                let state_storage = state_storage.clone();
 
                 async move {
                     match (req.method(), req.uri().path()) {
                         (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
                             let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
-                            llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector)
+                            llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
                                 .with_context(parent_cx)
                                 .await
                         }
diff --git a/crates/brightstaff/src/state/memory.rs b/crates/brightstaff/src/state/memory.rs
new file mode 100644
index 00000000..d805d655
--- /dev/null
+++ b/crates/brightstaff/src/state/memory.rs
@@ -0,0 +1,611 @@
+use super::{OpenAIConversationState, StateStorage, StateStorageError};
+use async_trait::async_trait;
+use std::collections::HashMap;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+use tracing::{debug, warn};
+
+/// In-memory storage backend for conversation state.
+/// Uses a HashMap wrapped in Arc<RwLock<...>> for thread-safe access.
+#[derive(Clone)]
+pub struct MemoryConversationalStorage {
+    storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
+}
+
+impl MemoryConversationalStorage {
+    pub fn new() -> Self {
+        Self {
+            storage: Arc::new(RwLock::new(HashMap::new())),
+        }
+    }
+}
+
+impl Default for MemoryConversationalStorage {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[async_trait]
+impl StateStorage for MemoryConversationalStorage {
+    async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
+        let response_id = state.response_id.clone();
+        let mut storage = self.storage.write().await;
+
+        debug!(
+            "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Storing conversation state: model={}, provider={}, input_items={}",
+            response_id, state.model, state.provider, state.input_items.len()
+        );
+
+        storage.insert(response_id, state);
+        Ok(())
+    }
+
+    async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
+        let storage = self.storage.read().await;
+
+        match storage.get(response_id) {
+            Some(state) => {
+                debug!(
+                    "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Retrieved conversation state: input_items={}",
+                    response_id, state.input_items.len()
+                );
+                Ok(state.clone())
+            }
+            None => {
+                warn!(
+                    "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Conversation state not found",
+                    response_id
+                );
+                Err(StateStorageError::NotFound(response_id.to_string()))
+            }
+        }
+    }
+
+    async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
+        let storage = self.storage.read().await;
+        Ok(storage.contains_key(response_id))
+    }
+
+    async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
+        let mut storage = self.storage.write().await;
+
+        if storage.remove(response_id).is_some() {
+            debug!(
+                "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Deleted conversation state",
+                response_id
+            );
+            Ok(())
+        } else {
+            Err(StateStorageError::NotFound(response_id.to_string()))
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent};
+
+    fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
+        let mut input_items = Vec::new();
+        for i in 0..num_messages {
+            input_items.push(InputItem::Message(InputMessage {
+                role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
+                content: MessageContent::Items(vec![InputContent::InputText {
+                    text: format!("Message {}", i),
+                }]),
+            }));
+        }
+
+        OpenAIConversationState {
+            response_id: response_id.to_string(),
+            input_items,
+            created_at: 1234567890,
+            model: "claude-3".to_string(),
+            provider: "anthropic".to_string(),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_put_and_get_success() {
+        let storage = MemoryConversationalStorage::new();
+        let state: OpenAIConversationState = create_test_state("resp_001", 3);
+
+        // Store
+        storage.put(state.clone()).await.unwrap();
+
+        // Retrieve
+        let retrieved = storage.get("resp_001").await.unwrap();
+        assert_eq!(retrieved.response_id, state.response_id);
+        assert_eq!(retrieved.model, state.model);
+        assert_eq!(retrieved.provider, state.provider);
+        assert_eq!(retrieved.input_items.len(), 3);
+        assert_eq!(retrieved.created_at, state.created_at);
+    }
+
+    #[tokio::test]
+    async fn test_put_overwrites_existing() {
+        let storage = MemoryConversationalStorage::new();
+
+        // First state
+        let state1 = create_test_state("resp_002", 2);
+        storage.put(state1).await.unwrap();
+
+        // Overwrite with new state
+        let state2 = OpenAIConversationState {
+            response_id: "resp_002".to_string(),
+            input_items: vec![],
+            created_at: 9999999999,
+            model: "gpt-4".to_string(),
+            provider: "openai".to_string(),
+        };
+        storage.put(state2.clone()).await.unwrap();
+
+        // Should retrieve the new state
+        let retrieved = storage.get("resp_002").await.unwrap();
+        assert_eq!(retrieved.model, "gpt-4");
+        assert_eq!(retrieved.provider, "openai");
+        assert_eq!(retrieved.input_items.len(), 0);
+        assert_eq!(retrieved.created_at, 9999999999);
+    }
+
+    #[tokio::test]
+    async fn test_get_not_found() {
+        let storage = MemoryConversationalStorage::new();
+
+        let result = storage.get("nonexistent").await;
+        assert!(result.is_err());
+
+        match result.unwrap_err() {
+            StateStorageError::NotFound(id) => {
+                assert_eq!(id, "nonexistent");
+            }
+            _ => panic!("Expected NotFound error"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_exists_returns_false_for_nonexistent() {
+        let storage = MemoryConversationalStorage::new();
+        assert!(!storage.exists("resp_003").await.unwrap());
+    }
+
+    #[tokio::test]
+    async fn test_exists_returns_true_after_put() {
+        let storage = MemoryConversationalStorage::new();
+        let state = create_test_state("resp_004", 1);
+
+        assert!(!storage.exists("resp_004").await.unwrap());
+        storage.put(state).await.unwrap();
+        assert!(storage.exists("resp_004").await.unwrap());
+    }
+
+    #[tokio::test]
+    async fn test_delete_success() {
+        let storage = MemoryConversationalStorage::new();
+        let state = create_test_state("resp_005", 2);
+
+        storage.put(state).await.unwrap();
+        assert!(storage.exists("resp_005").await.unwrap());
+
+        // Delete
+        storage.delete("resp_005").await.unwrap();
+
+        // Should no longer exist
+        assert!(!storage.exists("resp_005").await.unwrap());
+        assert!(storage.get("resp_005").await.is_err());
+    }
+
+    #[tokio::test]
+    async fn test_delete_not_found() {
+        let storage = MemoryConversationalStorage::new();
+
+        let result = storage.delete("nonexistent").await;
+        assert!(result.is_err());
+
+        match result.unwrap_err() {
+            StateStorageError::NotFound(id) => {
+                assert_eq!(id, "nonexistent");
+            }
+            _ => panic!("Expected NotFound error"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_merge_combines_inputs() {
+        let storage = MemoryConversationalStorage::new();
+
+        // Create a previous state with 2 messages
+        let prev_state = create_test_state("resp_006", 2);
+
+        // Create current input with 1 message
+        let current_input = vec![InputItem::Message(InputMessage {
+            role: MessageRole::User,
+            content: MessageContent::Items(vec![InputContent::InputText {
+                text: "New message".to_string(),
+            }]),
+        })];
+
+        // Merge
+        let merged = storage.merge(&prev_state, current_input);
+
+        // Should have 3 messages total (2 from prev + 1 current)
+        assert_eq!(merged.len(), 3);
+    }
+
+    #[tokio::test]
+    async fn test_merge_preserves_order() {
+        let storage = MemoryConversationalStorage::new();
+
+        // Previous state has messages 0 and 1
+        let prev_state = create_test_state("resp_007", 2);
+
+        // Current input has message 2
+        let current_input = vec![InputItem::Message(InputMessage {
+            role: MessageRole::User,
+            content: MessageContent::Items(vec![InputContent::InputText {
+                text: "Message 2".to_string(),
+            }]),
+        })];
+
+        let merged = storage.merge(&prev_state, current_input);
+
+        // Verify order: prev messages first, then current
+        let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") };
+        match &msg.content {
+            MessageContent::Items(items) => match &items[0] {
+                InputContent::InputText { text } => assert_eq!(text, "Message 0"),
+                _ => panic!("Expected InputText"),
+            },
+            _ => panic!("Expected MessageContent::Items"),
+        }
+
+        let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") };
+        match &msg.content {
+            MessageContent::Items(items) => match &items[0] {
+                InputContent::InputText { text } => assert_eq!(text, "Message 2"),
+                _ => panic!("Expected InputText"),
+            },
+            _ => panic!("Expected MessageContent::Items"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_merge_with_empty_current_input() {
+        let storage = MemoryConversationalStorage::new();
+        let prev_state = create_test_state("resp_008", 3);
+
+        let merged = storage.merge(&prev_state, vec![]);
+
+        // Should just have the previous state's items
+        assert_eq!(merged.len(), 3);
+    }
+
+    #[tokio::test]
+    async fn test_merge_with_empty_previous_state() {
+        let storage = MemoryConversationalStorage::new();
+
+        let prev_state = OpenAIConversationState {
+            response_id: "resp_009".to_string(),
+            input_items: vec![],
+            created_at: 1234567890,
+            model: "gpt-4".to_string(),
+            provider: "openai".to_string(),
+        };
+
+        let current_input = vec![InputItem::Message(InputMessage {
+            role: MessageRole::User,
+            content: MessageContent::Items(vec![InputContent::InputText {
+                text: "Only message".to_string(),
+            }]),
+        })];
+
+        let merged = storage.merge(&prev_state, current_input);
+
+        // Should just have the current input
+        assert_eq!(merged.len(), 1);
+    }
+
+    #[tokio::test]
+    async fn test_concurrent_access() {
+        let storage = MemoryConversationalStorage::new();
+
+        // Spawn multiple tasks that write concurrently
+        let mut handles = vec![];
+
+        for i in 0..10 {
+            let storage_clone = storage.clone();
+            let handle = tokio::spawn(async move {
+                let state = create_test_state(&format!("resp_{}", i), i % 3);
+                storage_clone.put(state).await.unwrap();
+            });
+            handles.push(handle);
+        }
+
+        // Wait for all tasks
+        for handle in handles {
+            handle.await.unwrap();
+        }
+
+        // Verify all states were stored
+        for i in 0..10 {
+            assert!(storage.exists(&format!("resp_{}", i)).await.unwrap());
+        }
+    }
+
+    #[tokio::test]
+    async fn test_multiple_operations_on_same_id() {
+        let storage = MemoryConversationalStorage::new();
+        let state = create_test_state("resp_010", 1);
+
+        // Put
+        storage.put(state.clone()).await.unwrap();
+
+        // Get
+        let retrieved = storage.get("resp_010").await.unwrap();
+        assert_eq!(retrieved.response_id, "resp_010");
+
+        // Exists
+        assert!(storage.exists("resp_010").await.unwrap());
+
+        // Put again (overwrite)
+        let new_state = create_test_state("resp_010", 5);
+        storage.put(new_state).await.unwrap();
+
+        // Get updated
+        let updated = storage.get("resp_010").await.unwrap();
+        assert_eq!(updated.input_items.len(), 5);
+
+        // Delete
+        storage.delete("resp_010").await.unwrap();
+
+        // Should not exist
+        assert!(!storage.exists("resp_010").await.unwrap());
+    }
+
+    #[tokio::test]
+    async fn test_merge_with_tool_call_flow() {
+        // This test simulates a realistic tool call conversation flow:
+        // 1. User sends message: "What's the weather?"
+        // 2. Model responds with function call (converted to assistant message)
+        // 3. User sends function call output in next request with previous_response_id
+        // The merge should combine: user message + assistant function call + function output
+
+        let storage = MemoryConversationalStorage::new();
+
+        // Step 1: Previous state contains the initial exchange
+        // - User message: "What's the weather in SF?"
+        // - Assistant message (converted from FunctionCall): "Called function: get_weather..."
+        let prev_state = OpenAIConversationState {
+            response_id: "resp_tool_001".to_string(),
+            input_items: vec![
+                // Original user message
+                InputItem::Message(InputMessage {
+                    role: MessageRole::User,
+                    content: MessageContent::Items(vec![InputContent::InputText {
+                        text: "What's the weather in San Francisco?".to_string(),
+                    }]),
+                }),
+                // Assistant's function call (converted from OutputItem::FunctionCall)
+                InputItem::Message(InputMessage {
+                    role: MessageRole::Assistant,
+                    content: MessageContent::Items(vec![InputContent::InputText {
+                        text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(),
+                    }]),
+                }),
+            ],
+            created_at: 1234567890,
+            model: "claude-3".to_string(),
+            provider: "anthropic".to_string(),
+        };
+
+        // Step 2: Current request includes function call output
+        let current_input = vec![InputItem::Message(InputMessage {
+            role: MessageRole::User,
+            content: MessageContent::Items(vec![InputContent::InputText {
+                text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
+            }]),
+        })];
+
+        // Step 3: Merge should combine all conversation history
+        let merged = storage.merge(&prev_state, current_input);
+
+        // Should have 3 items: user question + assistant function call + function output
+        assert_eq!(merged.len(), 3);
+
+        // Verify the order and content
+        let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") };
+        assert!(matches!(msg1.role, MessageRole::User));
+        match &msg1.content {
+            MessageContent::Items(items) => match &items[0] {
+                InputContent::InputText { text } => {
+                    assert!(text.contains("weather in San Francisco"));
+                }
+                _ => panic!("Expected InputText"),
+            },
+            _ => panic!("Expected MessageContent::Items"),
+        }
+
+        let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") };
+        assert!(matches!(msg2.role, MessageRole::Assistant));
+        match &msg2.content {
+            MessageContent::Items(items) => match &items[0] {
+                InputContent::InputText { text } => {
+                    assert!(text.contains("get_weather"));
+                }
+                _ => panic!("Expected InputText"),
+            },
+            _ => panic!("Expected MessageContent::Items"),
+        }
+
+        let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") };
+        assert!(matches!(msg3.role, MessageRole::User));
+        match &msg3.content {
+            MessageContent::Items(items) => match &items[0] {
+                InputContent::InputText { text } => {
+                    assert!(text.contains("Function result"));
+                    assert!(text.contains("temperature"));
+                }
+                _ => panic!("Expected InputText"),
+            },
+            _ => panic!("Expected MessageContent::Items"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_merge_with_multiple_tool_calls() {
+        // Test a more complex scenario with multiple tool calls
+        let storage = MemoryConversationalStorage::new();
+
+        // Previous state has: user message + 2 function calls from assistant
+        let prev_state = OpenAIConversationState {
+            response_id: "resp_tool_002".to_string(),
+            input_items: vec![
+                InputItem::Message(InputMessage {
+                    role: MessageRole::User,
+                    content: MessageContent::Items(vec![InputContent::InputText {
+                        text: "What's the weather and time in SF?".to_string(),
+                    }]),
+                }),
+                InputItem::Message(InputMessage {
+                    role: MessageRole::Assistant,
+                    content: MessageContent::Items(vec![InputContent::InputText {
+                        text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(),
+                    }]),
+                }),
+                InputItem::Message(InputMessage {
+                    role: MessageRole::Assistant,
+                    content: MessageContent::Items(vec![InputContent::InputText {
+                        text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(),
+                    }]),
+                }),
+            ],
+            created_at: 1234567890,
+            model: "gpt-4".to_string(),
+            provider: "openai".to_string(),
+        };
+
+        // Current input: function outputs for both calls
+        let current_input = vec![
+            InputItem::Message(InputMessage {
+                role: MessageRole::User,
+                content: MessageContent::Items(vec![InputContent::InputText {
+                    text: "Weather result: {\"temp\": 68}".to_string(),
+                }]),
+            }),
+            InputItem::Message(InputMessage {
+                role: MessageRole::User,
+                content: MessageContent::Items(vec![InputContent::InputText {
+                    text: "Time result: {\"time\": \"14:30\"}".to_string(),
+                }]),
+            }),
+        ];
+
+        let merged = storage.merge(&prev_state, current_input);
+
+        // Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs
+        assert_eq!(merged.len(), 5);
+
+        // Verify first item is original user message
+        let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
+        assert!(matches!(first.role, MessageRole::User));
+
+        // Verify last two are function outputs
+        let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") };
+        assert!(matches!(second_last.role, MessageRole::User));
+        match &second_last.content {
+            MessageContent::Items(items) => match &items[0] {
+                InputContent::InputText { text } => assert!(text.contains("Weather result")),
+                _ => panic!("Expected InputText"),
+            },
+            _ => panic!("Expected MessageContent::Items"),
+        }
+
+        let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
+        assert!(matches!(last.role, MessageRole::User));
+        match &last.content {
+            MessageContent::Items(items) => match &items[0] {
+                InputContent::InputText { text } => assert!(text.contains("Time result")),
+                _ => panic!("Expected InputText"),
+            },
+            _ => panic!("Expected MessageContent::Items"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_merge_preserves_conversation_context_for_multi_turn() {
+        // Simulate a multi-turn conversation with tool calls
+        let storage = MemoryConversationalStorage::new();
+
+        // Previous state: full conversation history up to this point
+        let prev_state = OpenAIConversationState {
+            response_id: "resp_tool_003".to_string(),
+            input_items: vec![
+                // Turn 1: User asks about weather
+                InputItem::Message(InputMessage {
+                    role: MessageRole::User,
+                    content: MessageContent::Items(vec![InputContent::InputText {
+                        text: "What's the weather?".to_string(),
+                    }]),
+                }),
+                // Turn 1: Assistant calls get_weather
+                InputItem::Message(InputMessage {
+                    role: MessageRole::Assistant,
+                    content: MessageContent::Items(vec![InputContent::InputText {
+                        text: "Called function: get_weather".to_string(),
+                    }]),
+                }),
+                // Turn 2: User provides function output
+                InputItem::Message(InputMessage {
+                    role: MessageRole::User,
+                    content: MessageContent::Items(vec![InputContent::InputText {
+                        text: "Weather: sunny, 72°F".to_string(),
+                    }]),
+                }),
+                // Turn 2: Assistant responds with text
+                InputItem::Message(InputMessage {
+                    role: MessageRole::Assistant,
+                    content: MessageContent::Items(vec![InputContent::InputText {
+                        text: "It's sunny and 72°F in San Francisco today!".to_string(),
+                    }]),
+                }),
+            ],
+            created_at: 1234567890,
+            model: "claude-3".to_string(),
+            provider: "anthropic".to_string(),
+        };
+
+        // Turn 3: User asks follow-up question
+        let current_input = vec![InputItem::Message(InputMessage {
+            role: MessageRole::User,
+            content: MessageContent::Items(vec![InputContent::InputText {
+                text: "Should I bring an umbrella?".to_string(),
+            }]),
+        })];
+
+        let merged = storage.merge(&prev_state, current_input);
+
+        // Should have all 5 messages in order
+        assert_eq!(merged.len(), 5);
+
+        // Verify the entire conversation flow is preserved
+        let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
+        match &first.content {
+            MessageContent::Items(items) => match &items[0] {
+                InputContent::InputText { text } => assert!(text.contains("What's the weather")),
+                _ => panic!("Expected InputText"),
+            },
+            _ => panic!("Expected MessageContent::Items"),
+        }
+
+        let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
+        match &last.content {
+            MessageContent::Items(items) => match &items[0] {
+                InputContent::InputText { text } => assert!(text.contains("umbrella")),
+                _ => panic!("Expected InputText"),
+            },
+            _ => panic!("Expected MessageContent::Items"),
+        }
+    }
+}
diff --git a/crates/brightstaff/src/state/mod.rs b/crates/brightstaff/src/state/mod.rs
new file mode 100644
index 00000000..f2b96da0
--- /dev/null
+++ b/crates/brightstaff/src/state/mod.rs
@@ -0,0 +1,147 @@
+use async_trait::async_trait;
+use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam};
+use serde::{Deserialize, Serialize};
+use std::error::Error;
+use std::fmt;
+use std::sync::Arc;
+use tracing::debug;
+
+pub mod memory;
+pub mod response_state_processor;
+pub mod postgresql;
+
+/// Represents the conversational state for a v1/responses request.
+/// Contains the complete input/output history that can be restored.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OpenAIConversationState {
+    /// The response ID this state is associated with
+    pub response_id: String,
+
+    /// The complete input history (original input + accumulated outputs).
+    /// This is what gets prepended to new requests via previous_response_id.
+    pub input_items: Vec<InputItem>,
+
+    /// Timestamp when this state was created
+    pub created_at: i64,
+
+    /// Model used for this response
+    pub model: String,
+
+    /// Provider that generated this response (e.g., "anthropic", "openai")
+    pub provider: String,
+}
+
+/// Error types for state storage operations
+#[derive(Debug)]
+pub enum StateStorageError {
+    /// State not found for given response_id
+    NotFound(String),
+
+    /// Storage backend error (network, database, etc.)
+    StorageError(String),
+
+    /// Serialization/deserialization error
+    SerializationError(String),
+}
+
+impl fmt::Display for StateStorageError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id),
+            StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg),
+            StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
+        }
+    }
+}
+
+impl Error for StateStorageError {}
+
+/// Trait for conversation state storage backends
+#[async_trait]
+pub trait StateStorage: Send + Sync {
+    /// Store conversation state for a response
+    async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError>;
+
+    /// Retrieve conversation state by response_id
+    async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError>;
+
+    /// Check if state exists for a response_id
+    async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError>;
+
+    /// Delete state for a response_id (optional, for cleanup)
+    async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>;
+
+    /// Combine a previous conversation state with the current request input.
+    fn merge(
+        &self,
+        prev_state: &OpenAIConversationState,
+        current_input: Vec<InputItem>,
+    ) -> Vec<InputItem> {
+        // Default implementation: prepend previous input, append current
+        let prev_count = prev_state.input_items.len();
+        let current_count = current_input.len();
+
+        let mut combined_input = prev_state.input_items.clone();
+        combined_input.extend(current_input);
+
+        debug!(
+            "[PLANO | BRIGHTSTAFF | STATE_STORAGE] RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}",
+            prev_state.response_id,
+            prev_count,
+            current_count,
+            combined_input.len(),
+            serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
+        );
+
+        combined_input
+    }
+}
+
+/// Storage backend type enum
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum StorageBackend {
+    Memory,
+    Postgres,
+}
+
+impl StorageBackend {
+    pub fn from_str(s: &str) -> Option<Self> {
+        match s.to_lowercase().as_str() {
+            "memory" => Some(StorageBackend::Memory),
+            "postgres" => Some(StorageBackend::Postgres),
+            _ => None,
+        }
+    }
+}
+
+// === Utility functions for state management ===
+
+/// Extract input items from InputParam, converting text to structured format
+pub fn extract_input_items(input: &InputParam) -> Vec<InputItem> {
+    match input {
+        InputParam::Text(text) => {
+            vec![InputItem::Message(InputMessage {
+                role: MessageRole::User,
+                content: MessageContent::Items(vec![InputContent::InputText {
+                    text: text.clone(),
+                }]),
+            })]
+        }
+        InputParam::Items(items) => items.clone(),
+    }
+}
+
+/// Retrieve the previous conversation state and combine it with the current input.
+/// Returns the combined input, or an error if the previous state cannot be fetched.
+pub async fn retrieve_and_combine_input(
+    storage: Arc<dyn StateStorage>,
+    previous_response_id: &str,
+    current_input: Vec<InputItem>,
+) -> Result<Vec<InputItem>, StateStorageError> {
+    // First get the previous state
+    let prev_state = storage.get(previous_response_id).await?;
+    let combined_input = storage.merge(&prev_state, current_input);
+    Ok(combined_input)
+}
diff --git a/crates/brightstaff/src/state/postgresql.rs b/crates/brightstaff/src/state/postgresql.rs
new file mode 100644
index 00000000..529f27e9
--- /dev/null
+++ b/crates/brightstaff/src/state/postgresql.rs
@@ -0,0 +1,432 @@
+use super::{OpenAIConversationState, StateStorage, StateStorageError};
+use async_trait::async_trait;
+use serde_json;
+use std::sync::Arc;
+use tokio::sync::OnceCell;
diff --git a/crates/brightstaff/src/state/postgresql.rs b/crates/brightstaff/src/state/postgresql.rs
new file mode 100644
index 00000000..529f27e9
--- /dev/null
+++ b/crates/brightstaff/src/state/postgresql.rs
@@ -0,0 +1,432 @@
+use super::{OpenAIConversationState, StateStorage, StateStorageError};
+use async_trait::async_trait;
+use serde_json;
+use std::sync::Arc;
+use tokio::sync::OnceCell;
+use tokio_postgres::{Client, NoTls};
+use tracing::{debug, info, warn};
+
+/// Supabase/PostgreSQL storage backend for conversation state
+#[derive(Clone)]
+pub struct PostgreSQLConversationStorage {
+    client: Arc<Client>,
+    table_verified: Arc<OnceCell<()>>,
+}
+
+impl PostgreSQLConversationStorage {
+    /// Creates a new Supabase storage instance with the given connection string
+    pub async fn new(connection_string: String) -> Result<Self, StateStorageError> {
+        let (client, connection) = tokio_postgres::connect(&connection_string, NoTls)
+            .await
+            .map_err(|e| {
+                StateStorageError::StorageError(format!("Failed to connect to database: {}", e))
+            })?;
+
+        // Spawn the connection to run in the background
+        tokio::spawn(async move {
+            if let Err(e) = connection.await {
+                warn!("Database connection error: {}", e);
+            }
+        });
+
+        Ok(Self {
+            client: Arc::new(client),
+            table_verified: Arc::new(OnceCell::new()),
+        })
+    }
+
+    /// Ensures the conversation_states table exists (checks once, caches result)
+    async fn ensure_ready(&self) -> Result<(), StateStorageError> {
+        self.table_verified
+            .get_or_try_init(|| async {
+                let row = self
+                    .client
+                    .query_one(
+                        "SELECT EXISTS (
+                            SELECT FROM pg_tables
+                            WHERE tablename = 'conversation_states'
+                        )",
+                        &[],
+                    )
+                    .await
+                    .map_err(|e| {
+                        StateStorageError::StorageError(format!(
+                            "Failed to verify table existence: {}",
+                            e
+                        ))
+                    })?;
+
+                let exists: bool = row.get(0);
+
+                if !exists {
+                    return Err(StateStorageError::StorageError(
+                        "Table 'conversation_states' does not exist. \
+                         Please run the setup SQL from docs/db_setup/conversation_states.sql"
+                            .to_string(),
+                    ));
+                }
+
+                info!("Conversation state storage table verified");
+                Ok(())
+            })
+            .await?;
+
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl StateStorage for PostgreSQLConversationStorage {
+    async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
+        self.ensure_ready().await?;
+
+        // Serialize input_items to JSONB
+        let input_items_json = serde_json::to_value(&state.input_items).map_err(|e| {
+            StateStorageError::StorageError(format!("Failed to serialize input_items: {}", e))
+        })?;
+
+        // Upsert the conversation state
+        self.client
+            .execute(
+                r#"
+                INSERT INTO conversation_states
+                    (response_id, input_items, created_at, model, provider, updated_at)
+                VALUES ($1, $2, $3, $4, $5, NOW())
+                ON CONFLICT (response_id)
+                DO UPDATE SET
+                    input_items = EXCLUDED.input_items,
+                    model = EXCLUDED.model,
+                    provider = EXCLUDED.provider,
+                    updated_at = NOW()
+                "#,
+                &[
+                    &state.response_id,
+                    &input_items_json,
+                    &state.created_at,
+                    &state.model,
+                    &state.provider,
+                ],
+            )
+            .await
+            .map_err(|e| {
+                StateStorageError::StorageError(format!(
+                    "Failed to store conversation state for {}: {}",
+                    state.response_id, e
+                ))
+            })?;
+
+        debug!("Stored conversation state for {}", state.response_id);
+        Ok(())
+    }
+
+    async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
+        self.ensure_ready().await?;
+
+        let row = self
+            .client
+            .query_opt(
+                r#"
+                SELECT response_id, input_items, created_at, model, provider
+                FROM conversation_states
+                WHERE response_id = $1
+                "#,
+                &[&response_id],
+            )
+            .await
+            .map_err(|e| {
+                StateStorageError::StorageError(format!(
+                    "Failed to fetch conversation state for {}: {}",
+                    response_id, e
+                ))
+            })?;
+
+        match row {
+            Some(row) => {
+                let response_id: String = row.get("response_id");
+                let input_items_json: serde_json::Value = row.get("input_items");
+                let created_at: i64 = row.get("created_at");
+                let model: String = row.get("model");
+                let provider: String = row.get("provider");
+
+                // Deserialize input_items from JSONB
+                let input_items = serde_json::from_value(input_items_json).map_err(|e| {
+                    StateStorageError::StorageError(format!(
+                        "Failed to deserialize input_items: {}",
+                        e
+                    ))
+                })?;
+
+                Ok(OpenAIConversationState {
+                    response_id,
+                    input_items,
+                    created_at,
+                    model,
+                    provider,
+                })
+            }
+            None => Err(StateStorageError::NotFound(format!(
+                "Conversation state not found for response_id: {}",
+                response_id
+            ))),
+        }
+    }
+
+    async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
+        self.ensure_ready().await?;
+
+        let row = self
+            .client
+            .query_one(
+                "SELECT EXISTS(SELECT 1 FROM conversation_states WHERE response_id = $1)",
+                &[&response_id],
+            )
+            .await
+            .map_err(|e| {
+                StateStorageError::StorageError(format!(
+                    "Failed to check existence for {}: {}",
+                    response_id, e
+                ))
+            })?;
+
+        let exists: bool = row.get(0);
+        Ok(exists)
+    }
+
+    async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
+        self.ensure_ready().await?;
+
+        let rows_affected = self
+            .client
+            .execute(
+                "DELETE FROM conversation_states WHERE response_id = $1",
+                &[&response_id],
+            )
+            .await
+            .map_err(|e| {
+                StateStorageError::StorageError(format!(
+                    "Failed to delete conversation state for {}: {}",
+                    response_id, e
+                ))
+            })?;
+
+        if rows_affected == 0 {
+            return Err(StateStorageError::NotFound(format!(
+                "Conversation state not found for response_id: {}",
+                response_id
+            )));
+        }
+
+        debug!("Deleted conversation state for {}", response_id);
+        Ok(())
+    }
+}
+
+/*
+PostgreSQL schema is maintained in docs/db_setup/conversation_states.sql
+Run that SQL file against your database before using this storage backend.
+*/
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole};
+
+    fn create_test_state(response_id: &str) -> OpenAIConversationState {
+        OpenAIConversationState {
+            response_id: response_id.to_string(),
+            input_items: vec![InputItem::Message(InputMessage {
+                role: MessageRole::User,
+                content: MessageContent::Items(vec![InputContent::InputText {
+                    text: "Test message".to_string(),
+                }]),
+            })],
+            created_at: 1234567890,
+            model: "gpt-4".to_string(),
+            provider: "openai".to_string(),
+        }
+    }
+
+    // Note: These tests require a running PostgreSQL database
+    // Set TEST_DATABASE_URL environment variable to run integration tests
+    // Example: TEST_DATABASE_URL=postgresql://user:pass@localhost/test_db
+
+    async fn get_test_storage() -> Option<PostgreSQLConversationStorage> {
+        if let Ok(db_url) = std::env::var("TEST_DATABASE_URL") {
+            match PostgreSQLConversationStorage::new(db_url).await {
+                Ok(storage) => Some(storage),
+                Err(e) => {
+                    eprintln!("Failed to create test storage: {}", e);
+                    None
+                }
+            }
+        } else {
+            eprintln!("TEST_DATABASE_URL not set, skipping Supabase integration tests");
+            None
+        }
+    }
+
+    #[tokio::test]
+    async fn test_supabase_put_and_get_success() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        let state = create_test_state("test_resp_001");
+        storage.put(state.clone()).await.unwrap();
+
+        let retrieved = storage.get("test_resp_001").await.unwrap();
+        assert_eq!(retrieved.response_id, "test_resp_001");
+        assert_eq!(retrieved.input_items.len(), 1);
+        assert_eq!(retrieved.model, "gpt-4");
+        assert_eq!(retrieved.provider, "openai");
+
+        // Cleanup
+        let _ = storage.delete("test_resp_001").await;
+    }
+
+    #[tokio::test]
+    async fn test_supabase_put_overwrites_existing() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        let state1 = create_test_state("test_resp_002");
+        storage.put(state1).await.unwrap();
+
+        let mut state2 = create_test_state("test_resp_002");
+        state2.model = "gpt-4-turbo".to_string();
+        state2.input_items.push(InputItem::Message(InputMessage {
+            role: MessageRole::Assistant,
+            content: MessageContent::Items(vec![InputContent::InputText {
+                text: "Response".to_string(),
+            }]),
+        }));
+        storage.put(state2).await.unwrap();
+
+        let retrieved = storage.get("test_resp_002").await.unwrap();
+        assert_eq!(retrieved.model, "gpt-4-turbo");
+        assert_eq!(retrieved.input_items.len(), 2);
+
+        // Cleanup
+        let _ = storage.delete("test_resp_002").await;
+    }
+
+    #[tokio::test]
+    async fn test_supabase_get_not_found() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        let result = storage.get("nonexistent_id").await;
+        assert!(result.is_err());
+        assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
+    }
+
+    #[tokio::test]
+    async fn test_supabase_exists_returns_false() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        let exists = storage.exists("nonexistent_id").await.unwrap();
+        assert!(!exists);
+    }
+
+    #[tokio::test]
+    async fn test_supabase_exists_returns_true_after_put() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        let state = create_test_state("test_resp_003");
+        storage.put(state).await.unwrap();
+
+        let exists = storage.exists("test_resp_003").await.unwrap();
+        assert!(exists);
+
+        // Cleanup
+        let _ = storage.delete("test_resp_003").await;
+    }
+
+    #[tokio::test]
+    async fn test_supabase_delete_success() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        let state = create_test_state("test_resp_004");
+        storage.put(state).await.unwrap();
+
+        storage.delete("test_resp_004").await.unwrap();
+
+        let exists = storage.exists("test_resp_004").await.unwrap();
+        assert!(!exists);
+    }
+
+    #[tokio::test]
+    async fn test_supabase_delete_not_found() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        let result = storage.delete("nonexistent_id").await;
+        assert!(result.is_err());
+        assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
+    }
+
+    #[tokio::test]
+    async fn test_supabase_merge_works() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        let prev_state = create_test_state("test_resp_005");
+        let current_input = vec![InputItem::Message(InputMessage {
+            role: MessageRole::User,
+            content: MessageContent::Items(vec![InputContent::InputText {
+                text: "New message".to_string(),
+            }]),
+        })];
+
+        let merged = storage.merge(&prev_state, current_input);
+
+        // Should have 2 messages (1 from prev + 1 current)
+        assert_eq!(merged.len(), 2);
+    }
+
+    #[tokio::test]
+    async fn test_supabase_table_verification() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        // This should trigger table verification
+        let result = storage.ensure_ready().await;
+        assert!(result.is_ok(), "Table verification should succeed");
+
+        // Second call should use cached result
+        let result2 = storage.ensure_ready().await;
+        assert!(result2.is_ok(), "Cached verification should succeed");
+    }
+
+    #[tokio::test]
+    #[ignore] // Run manually with: cargo test test_verify_data_in_supabase -- --ignored
+    async fn test_verify_data_in_supabase() {
+        let Some(storage) = get_test_storage().await else {
+            return;
+        };
+
+        // Create a test record that persists
+        let state = create_test_state("manual_test_verification");
+        storage.put(state).await.unwrap();
+
+        println!("✅ Data written to Supabase!");
+        println!("Check your Supabase dashboard:");
+        println!("  SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';");
+        println!("\nTo cleanup, run:");
+        println!("  DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';");
+
+        // DON'T cleanup - leave it for manual verification
+    }
+}
diff --git a/crates/brightstaff/src/state/response_state_processor.rs b/crates/brightstaff/src/state/response_state_processor.rs
new file mode 100644
index 00000000..b3ce6787
--- /dev/null
+++ b/crates/brightstaff/src/state/response_state_processor.rs
@@ -0,0 +1,302 @@
+use bytes::Bytes;
+use flate2::read::GzDecoder;
+use hermesllm::apis::openai_responses::{
+    InputItem, OutputItem, ResponsesAPIResponse, ResponsesAPIStreamEvent,
+};
+use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
+use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
+use std::io::Read;
+use std::sync::Arc;
+use tracing::{info, debug, warn};
+
+use crate::handlers::utils::StreamProcessor;
+use crate::state::{OpenAIConversationState, StateStorage};
+
+/// Processor that wraps another processor and handles v1/responses state management
+/// Captures response_id and output from streaming responses, stores state after completion
+pub struct ResponsesStateProcessor<P> {
+    /// The underlying processor (e.g., ObservableStreamProcessor for metrics)
+    inner: P,
+
+    /// State storage backend
+    storage: Arc<dyn StateStorage>,
+
+    /// Original input items from the request
+    original_input: Vec<InputItem>,
+
+    /// Model name
+    model: String,
+
+    /// Provider name
+    provider: String,
+
+    /// Whether this is a streaming request
+    is_streaming: bool,
+
+    /// Whether upstream is OpenAI (skip storage if true)
+    is_openai_upstream: bool,
+
+    /// Content-Encoding header value (e.g., "gzip", "br", None)
+    content_encoding: Option<String>,
+
+    /// Request ID for logging
+    request_id: String,
+
+    /// Buffer for accumulating chunks (needed for non-streaming compressed responses)
+    chunk_buffer: Vec<u8>,
+
+    /// Captured response_id from response.completed event
+    response_id: Option<String>,
+
+    /// Captured output items from response.completed event
+    output_items: Option<Vec<OutputItem>>,
+}
+
+impl<P> ResponsesStateProcessor<P> {
+    pub fn new(
+        inner: P,
+        storage: Arc<dyn StateStorage>,
+        original_input: Vec<InputItem>,
+        model: String,
+        provider: String,
+        is_streaming: bool,
+        is_openai_upstream: bool,
+        content_encoding: Option<String>,
+        request_id: String,
+    ) -> Self {
+        Self {
+            inner,
+            storage,
+            original_input,
+            model,
+            provider,
+            is_streaming,
+            is_openai_upstream,
+            content_encoding,
+            request_id,
+            chunk_buffer: Vec::new(),
+            response_id: None,
+            output_items: None,
+        }
+    }
+
+    /// Decompress accumulated buffer based on Content-Encoding header
+    fn decompress_buffer(&self) -> Vec<u8> {
+        if self.chunk_buffer.is_empty() {
+            return Vec::new();
+        }
+
+        match self.content_encoding.as_deref() {
+            Some("gzip") => {
+                let mut decoder = GzDecoder::new(self.chunk_buffer.as_slice());
+                let mut decompressed = Vec::new();
+                match decoder.read_to_end(&mut decompressed) {
+                    Ok(_) => {
+                        debug!(
+                            "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully decompressed {} bytes to {} bytes",
+                            self.request_id,
+                            self.chunk_buffer.len(),
+                            decompressed.len()
+                        );
+                        decompressed
+                    }
+                    Err(e) => {
+                        warn!(
+                            "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to decompress gzip buffer: {}",
+                            self.request_id,
+                            e
+                        );
+                        self.chunk_buffer.clone()
+                    }
+                }
+            }
+            Some(encoding) => {
+                warn!(
+                    "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Unsupported Content-Encoding: {}. Only gzip is currently supported.",
+                    self.request_id,
+                    encoding
+                );
+                self.chunk_buffer.clone()
+            }
+            None => self.chunk_buffer.clone(),
+        }
+    }
+
+    /// Parse response to extract response_id and output
+    /// For streaming: parse SSE events looking for response.completed (per chunk)
+    /// For non-streaming: buffer all chunks, then decompress and parse on completion
+    fn try_parse_response_chunk(&mut self, chunk: &[u8]) {
+        if self.is_streaming {
+            // Streaming: Try to parse SSE events from this chunk
+            // Note: For compressed streaming, we'd need to buffer and decompress first
+            // but most streaming responses aren't compressed since SSE needs to be readable
+            let sse_iter = match SseStreamIter::try_from(chunk) {
+                Ok(iter) => iter,
+                Err(_) => return, // Not valid SSE format, skip
+            };
+
+            // Process each SSE event in the chunk, looking for data lines with response.completed
+            for event in sse_iter {
+                // Only process data lines (skip event-only lines)
+                if let Some(data_str) = &event.data {
+                    // Try to parse as ResponsesAPIStreamEvent
+                    if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) {
+                        // Check if this is a ResponseCompleted event
+                        if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event {
+                            info!(
+                                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
+                                self.request_id,
+                                response.id,
+                                response.output.len()
+                            );
+                            self.response_id = Some(response.id.clone());
+                            self.output_items = Some(response.output.clone());
+                            return; // Found what we need, exit early
+                        }
+                    }
+                }
+            }
+        } else {
+            // Non-streaming: Buffer chunks, will decompress and parse on completion
+            self.chunk_buffer.extend_from_slice(chunk);
+        }
+    }
+
+    /// Parse buffered non-streaming response (called on completion)
+    fn try_parse_buffered_response(&mut self) {
+        if self.is_streaming || self.chunk_buffer.is_empty() {
+            return;
+        }
+
+        // Decompress if needed
+        let decompressed = self.decompress_buffer();
+
+        // Parse complete JSON response
+        match serde_json::from_slice::<ResponsesAPIResponse>(&decompressed) {
+            Ok(response) => {
+                info!(
+                    "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",
+                    self.request_id,
+                    response.id,
+                    response.output.len()
+                );
+                self.response_id = Some(response.id.clone());
+                self.output_items = Some(response.output.clone());
+            }
+            Err(e) => {
+                // Log parse error with chunk preview for debugging
+                let chunk_preview = String::from_utf8_lossy(&decompressed);
+                let preview_len = chunk_preview.len().min(200);
+                warn!(
+                    "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to parse non-streaming ResponsesAPIResponse: {}. Decompressed preview (first {} bytes): {}",
+                    self.request_id,
+                    e,
+                    preview_len,
+                    &chunk_preview[..preview_len]
+                );
+            }
+        }
+    }
+}
+
+impl<P: StreamProcessor> StreamProcessor for ResponsesStateProcessor<P> {
+    fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
+        // Buffer/parse chunk for response extraction
+        self.try_parse_response_chunk(&chunk);
+
+        // Forward to inner processor
+        self.inner.process_chunk(chunk)
+    }
+
+    fn on_first_bytes(&mut self) {
+        self.inner.on_first_bytes();
+    }
+
+    fn on_complete(&mut self) {
+        // For non-streaming, decompress and parse buffered response
+        self.try_parse_buffered_response();
+
+        // First, let the inner processor complete
+        self.inner.on_complete();
+
+        // Skip storage for OpenAI upstream
+        if self.is_openai_upstream {
+            debug!(
+                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Skipping state storage for OpenAI upstream provider",
+                self.request_id
+            );
+            return;
+        }
+
+        // Store state if we captured response_id and output
+        if let (Some(response_id), Some(output_items)) = (&self.response_id, &self.output_items) {
+            // Convert output items to input items for next request
+            let output_as_inputs = outputs_to_inputs(output_items);
+
+            debug!(
+                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Converting outputs to inputs: output_items_count={}, converted_input_items_count={}",
+                self.request_id, output_items.len(), output_as_inputs.len()
+            );
+
+            // Combine original input + output as new input history
+            let mut combined_input = self.original_input.clone();
+            combined_input.extend(output_as_inputs);
+
+            debug!(
+                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Storing state: original_input_count={}, combined_input_count={}, combined_json={}",
+                self.request_id,
+                self.original_input.len(),
+                combined_input.len(),
+                serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
+            );
+
+            let state = OpenAIConversationState {
+                response_id: response_id.clone(),
+                input_items: combined_input,
+                created_at: std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap_or_default()
+                    .as_secs() as i64,
+                model: self.model.clone(),
+                provider: self.provider.clone(),
+            };
+
+            // Store asynchronously (fire and forget with logging)
+            let storage = self.storage.clone();
+            let response_id_clone = response_id.clone();
+            let request_id = self.request_id.clone();
+            let items_count = state.input_items.len();
+            tokio::spawn(async move {
+                match storage.put(state).await {
+                    Ok(()) => {
+                        info!(
+                            "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully stored conversation state for response_id: {}, items_count={}",
+                            request_id,
+                            response_id_clone,
+                            items_count
+                        );
+                    }
+                    Err(e) => {
+                        warn!(
+                            "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to store conversation state for response_id {}: {}",
+                            request_id,
+                            response_id_clone,
+                            e
+                        );
+                    }
+                }
+            });
+        } else {
+            warn!(
+                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | No response_id captured from upstream response - cannot store conversation state. response_id present: {}, output present: {}",
+                self.request_id,
+                self.response_id.is_some(),
+                self.output_items.is_some()
+            );
+        }
+    }
+
+    fn on_error(&mut self, error: &str) {
+        self.inner.on_error(error);
+    }
+}
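+
+// --- Illustrative sketch: the smallest inner processor that
+// `ResponsesStateProcessor` can wrap. It assumes the `StreamProcessor` trait
+// shape used above; chunks are forwarded unchanged and lifecycle hooks are no-ops.
+#[cfg(test)]
+mod passthrough_sketch {
+    use super::*;
+
+    struct Passthrough;
+
+    impl StreamProcessor for Passthrough {
+        fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
+            Ok(Some(chunk)) // forward bytes as-is
+        }
+        fn on_first_bytes(&mut self) {}
+        fn on_complete(&mut self) {}
+        fn on_error(&mut self, _error: &str) {}
+    }
+
+    #[test]
+    fn passthrough_forwards_chunks() {
+        let mut p = Passthrough;
+        let out = p.process_chunk(Bytes::from_static(b"data")).unwrap();
+        assert_eq!(out, Some(Bytes::from_static(b"data")));
+    }
+}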
diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs
index 27f8ebd9..2a4f983a 100644
--- a/crates/common/src/configuration.rs
+++ b/crates/common/src/configuration.rs
@@ -41,6 +41,20 @@ pub struct Listener {
     pub port: u16,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StateStorageConfig {
+    #[serde(rename = "type")]
+    pub storage_type: StateStorageType,
+    pub connection_string: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "lowercase")]
+pub enum StateStorageType {
+    Memory,
+    Postgres,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Configuration {
     pub version: String,
@@ -58,6 +72,7 @@ pub struct Configuration {
     pub routing: Option,
     pub agents: Option>,
     pub listeners: Vec<Listener>,
+    pub state_storage: Option<StateStorageConfig>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
diff --git a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs
index 84854af3..ca8a9cfd 100644
--- a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs
+++ b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs
@@ -59,6 +59,11 @@ pub struct ResponsesAPIStreamBuffer {
     model: Option<String>,
     created_at: Option<i64>,
 
+    /// Full response metadata from upstream (tools, temperature, etc.)
+    /// This is extracted from the first upstream event and used to build
+    /// complete response.created and response.in_progress events
+    upstream_response_metadata: Option<ResponsesAPIResponse>,
+
     /// Lifecycle state flags
     created_emitted: bool,
     in_progress_emitted: bool,
@@ -88,6 +93,7 @@ impl ResponsesAPIStreamBuffer {
             response_id: None,
             model: None,
             created_at: None,
+            upstream_response_metadata: None,
             created_emitted: false,
             in_progress_emitted: false,
             output_items_added: HashMap::new(),
@@ -171,6 +177,15 @@ impl ResponsesAPIStreamBuffer {
     /// Build the base response object with current state
     fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse {
+        // If we have upstream metadata, use it as a base and update status/output
+        if let Some(upstream) = &self.upstream_response_metadata {
+            let mut response = upstream.clone();
+            response.status = status;
+            // Don't update output here - will be set in finalize()
+            return response;
+        }
+
+        // Fallback: build a minimal response from local state
         ResponsesAPIResponse {
             id: self.response_id.clone().unwrap_or_default(),
             object: "response".to_string(),
@@ -293,24 +308,40 @@ impl ResponsesAPIStreamBuffer {
         // Build final response
         let mut output_items = Vec::new();
 
-        // Add tool calls to output
-        for (item_id, arguments) in &self.function_arguments {
-            let output_index = self.output_items_added.iter()
-                .find(|(_, id)| *id == item_id)
-                .map(|(idx, _)| *idx)
-                .unwrap_or(0);
+        // Build complete output array by iterating through all output indices in order
+        let max_output_index = self.output_items_added.keys().max().copied().unwrap_or(-1);
 
-            let (call_id, name) = self.tool_call_metadata.get(&output_index)
-                .cloned()
-                .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
+        for output_index in 0..=max_output_index {
+            if let Some(item_id) = self.output_items_added.get(&output_index) {
+                // Check if this is a function call
+                if let Some(arguments) = self.function_arguments.get(item_id) {
+                    let (call_id, name) = self.tool_call_metadata.get(&output_index)
+                        .cloned()
+                        .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
 
-            output_items.push(OutputItem::FunctionCall {
-                id: item_id.clone(),
-                status: OutputItemStatus::Completed,
-                call_id,
-                name: Some(name),
-                arguments: Some(arguments.clone()),
-            });
+                    output_items.push(OutputItem::FunctionCall {
+                        id: item_id.clone(),
+                        status: OutputItemStatus::Completed,
+                        call_id,
+                        name: Some(name),
+                        arguments: Some(arguments.clone()),
+                    });
+                }
+                // Check if this is a text message
+                else if let Some(text) = self.text_content.get(item_id) {
+                    use crate::apis::openai_responses::OutputContent;
+                    output_items.push(OutputItem::Message {
+                        id: item_id.clone(),
+                        status: OutputItemStatus::Completed,
+                        role: "assistant".to_string(),
+                        content: vec![OutputContent::OutputText {
+                            text: text.clone(),
+                            annotations: vec![],
+                            logprobs: None,
+                        }],
+                    });
+                }
+            }
         }
 
         let mut final_response = self.build_response(ResponseStatus::Completed);
@@ -365,6 +396,24 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
 
         let mut events = Vec::new();
 
+        // Capture upstream metadata from ResponseCreated or ResponseInProgress if present
+        match stream_event {
+            ResponsesAPIStreamEvent::ResponseCreated { response, .. } |
+            ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
+                if self.upstream_response_metadata.is_none() {
+                    // Store the full upstream response as our metadata template
+                    self.upstream_response_metadata = Some(response.clone());
+                    // Also extract basic fields
+                    self.response_id = Some(response.id.clone());
+                    self.model = Some(response.model.clone());
+                    self.created_at = Some(response.created_at);
+                }
+                // Don't emit these - we'll generate our own lifecycle events
+                return;
+            }
+            _ => {}
+        }
+
         // Emit lifecycle events if not yet emitted
         if !self.created_emitted {
             // Initialize metadata from first event if needed
diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs
index 09ab262d..5a923329 100644
--- a/crates/hermesllm/src/clients/endpoints.rs
+++ b/crates/hermesllm/src/clients/endpoints.rs
@@ -193,6 +193,40 @@ impl SupportedAPIsFromClient {
     }
 }
 
+impl SupportedUpstreamAPIs {
+    /// Create a SupportedUpstreamAPIs variant from an endpoint path
+    pub fn from_endpoint(endpoint: &str) -> Option<Self> {
+        if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
+            // Check if this is the Responses API endpoint
+            if openai_api == OpenAIApi::Responses {
+                return Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(openai_api));
+            }
+            // Otherwise it's ChatCompletions
+            return Some(SupportedUpstreamAPIs::OpenAIChatCompletions(openai_api));
+        }
+
+        if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) {
+            return Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(anthropic_api));
+        }
+
+        if let Some(bedrock_api) = AmazonBedrockApi::from_endpoint(endpoint) {
+            match bedrock_api {
+                AmazonBedrockApi::Converse => {
+                    return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api))
+                }
+                AmazonBedrockApi::ConverseStream => {
+                    return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(bedrock_api))
+                }
+            }
+        }
+
+        None
+    }
+}
+
 /// Get all supported endpoint paths
 pub fn supported_endpoints() -> Vec<&'static str> {
     let mut endpoints = Vec::new();
diff --git a/crates/hermesllm/src/transforms/response/mod.rs b/crates/hermesllm/src/transforms/response/mod.rs
index 3ce75123..1dd0d4ea 100644
--- a/crates/hermesllm/src/transforms/response/mod.rs
+++ b/crates/hermesllm/src/transforms/response/mod.rs
@@ -1,3 +1,4 @@
 //! Response transformation modules
+pub mod output_to_input;
 pub mod to_anthropic;
 pub mod to_openai;
diff --git a/crates/hermesllm/src/transforms/response/output_to_input.rs b/crates/hermesllm/src/transforms/response/output_to_input.rs
new file mode 100644
index 00000000..8ab08205
--- /dev/null
+++ b/crates/hermesllm/src/transforms/response/output_to_input.rs
@@ -0,0 +1,178 @@
+//! Conversions from response outputs to request inputs for conversation continuation
+//!
+//! This module provides utilities for converting OutputItem types from API responses
+//! into InputItem types that can be used in subsequent requests. This is primarily used
+//! for maintaining conversation history in the v1/responses API.
+
+use crate::apis::openai_responses::{
+    InputContent, InputItem, InputMessage, MessageContent, MessageRole, OutputContent, OutputItem,
+};
+
+/// Converts an OutputItem from a response into an InputItem for the next request
+/// This is used to build conversation history from previous responses
+pub fn convert_responses_output_to_input_items(output: &OutputItem) -> Option<InputItem> {
+    match output {
+        // Convert output messages to input messages
+        OutputItem::Message {
+            role, content, ..
+        } => {
+            let input_content: Vec<InputContent> = content
+                .iter()
+                .filter_map(|c| match c {
+                    OutputContent::OutputText { text, .. } => Some(InputContent::InputText {
+                        text: text.clone(),
+                    }),
+                    OutputContent::OutputAudio {
+                        data, ..
+                    } => Some(InputContent::InputAudio {
+                        data: data.clone(),
+                        format: None, // Format not preserved in output
+                    }),
+                    OutputContent::Refusal { .. } => None, // Skip refusals
+                })
+                .collect();
+
+            if input_content.is_empty() {
+                return None;
+            }
+
+            // Map role string to MessageRole enum
+            let message_role = match role.as_str() {
+                "user" => MessageRole::User,
+                "assistant" => MessageRole::Assistant,
+                "system" => MessageRole::System,
+                "developer" => MessageRole::Developer,
+                _ => MessageRole::Assistant, // Default to assistant
+            };
+
+            Some(InputItem::Message(InputMessage {
+                role: message_role,
+                content: MessageContent::Items(input_content),
+            }))
+        }
+        // For function calls, we'll create an assistant message with the tool call info
+        // This matches how conversation history is typically built
+        OutputItem::FunctionCall {
+            name, arguments, ..
+        } => {
+            let tool_call_text = if let (Some(n), Some(args)) = (name, arguments) {
+                format!("Called function: {} with arguments: {}", n, args)
+            } else {
+                "Called a function".to_string()
+            };
+
+            Some(InputItem::Message(InputMessage {
+                role: MessageRole::Assistant,
+                content: MessageContent::Items(vec![InputContent::InputText {
+                    text: tool_call_text,
+                }]),
+            }))
+        }
+        // Skip other output types (tool outputs, etc.) as they don't convert to input
+        _ => None,
+    }
+}
+
+/// Converts a Vec of OutputItems into InputItems for conversation continuation
+pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec<InputItem> {
+    outputs
+        .iter()
+        .filter_map(convert_responses_output_to_input_items)
+        .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::apis::openai_responses::OutputItemStatus;
+
+    #[test]
+    fn test_output_message_to_input() {
+        let output = OutputItem::Message {
+            id: "msg_123".to_string(),
+            status: OutputItemStatus::Completed,
+            role: "assistant".to_string(),
+            content: vec![OutputContent::OutputText {
+                text: "Hello!".to_string(),
+                annotations: vec![],
+                logprobs: None,
+            }],
+        };
+
+        let input = convert_responses_output_to_input_items(&output).unwrap();
+
+        match input {
+            InputItem::Message(msg) => {
+                assert!(matches!(msg.role, MessageRole::Assistant));
+                match &msg.content {
+                    MessageContent::Items(items) => {
+                        assert_eq!(items.len(), 1);
+                        match &items[0] {
+                            InputContent::InputText { text } => assert_eq!(text, "Hello!"),
+                            _ => panic!("Expected InputText"),
+                        }
+                    }
+                    _ => panic!("Expected MessageContent::Items"),
+                }
+            }
+            _ => panic!("Expected Message variant"),
+        }
+    }
+
+    #[test]
+    fn test_function_call_to_input() {
+        let output = OutputItem::FunctionCall {
+            id: "fc_123".to_string(),
+            status: OutputItemStatus::Completed,
+            call_id: "call_123".to_string(),
+            name: Some("get_weather".to_string()),
+            arguments: Some(r#"{"location":"SF"}"#.to_string()),
+        };
+
+        let input = convert_responses_output_to_input_items(&output).unwrap();
+
+        match input {
+            InputItem::Message(msg) => {
+                assert!(matches!(msg.role, MessageRole::Assistant));
+                match &msg.content {
+                    MessageContent::Items(items) => {
+                        match &items[0] {
+                            InputContent::InputText { text } => {
+                                assert!(text.contains("get_weather"));
+                            }
+                            _ => panic!("Expected InputText"),
+                        }
+                    }
+                    _ => panic!("Expected MessageContent::Items"),
+                }
+            }
+            _ => panic!("Expected Message variant"),
+        }
+    }
+
+    #[test]
+    fn test_outputs_to_inputs() {
+        let outputs = vec![
+            OutputItem::Message {
+                id: "msg_1".to_string(),
+                status: OutputItemStatus::Completed,
+                role: "assistant".to_string(),
+                content: vec![OutputContent::OutputText {
+                    text: "Hello".to_string(),
+                    annotations: vec![],
+                    logprobs: None,
+                }],
+            },
+            OutputItem::FunctionCall {
+                id: "fc_1".to_string(),
+                status: OutputItemStatus::Completed,
+                call_id: "call_1".to_string(),
+                name: Some("test".to_string()),
+                arguments: Some("{}".to_string()),
+            },
+        ];
+
+        let inputs = outputs_to_inputs(&outputs);
+        assert_eq!(inputs.len(), 2);
+    }
+}
diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs
index e26cc3b4..d90d9035 100644
--- a/crates/hermesllm/src/transforms/response/to_openai.rs
+++ b/crates/hermesllm/src/transforms/response/to_openai.rs
@@ -80,8 +80,19 @@ impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
         // Only add the message item if there's actual content (text, audio, or refusal)
         // Don't add empty message items when there are only tool calls
         if !content.is_empty() {
+            // Generate message ID: strip common prefixes to avoid double-prefixing
+            let message_id = if resp.id.starts_with("msg_") {
+                resp.id.clone()
+            } else if resp.id.starts_with("resp_") {
+                format!("msg_{}", &resp.id[5..]) // Strip "resp_" prefix
+            } else if resp.id.starts_with("chatcmpl-") {
+                format!("msg_{}", &resp.id[9..]) // Strip "chatcmpl-" prefix
+            } else {
+                format!("msg_{}", resp.id)
+            };
+
             items.push(OutputItem::Message {
-                id: format!("msg_{}", resp.id),
format!("msg_{}", resp.id), + id: message_id, status: OutputItemStatus::Completed, role: match choice.message.role { Role::User => "user".to_string(), @@ -151,7 +162,12 @@ impl TryFrom for ResponsesAPIResponse { }; Ok(ResponsesAPIResponse { - id: resp.id, + // Generate proper resp_ prefixed ID if not already present + id: if resp.id.starts_with("resp_") { + resp.id + } else { + format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")) + }, object: "response".to_string(), created_at: resp.created as i64, status, @@ -942,7 +958,7 @@ mod tests { use crate::apis::openai_responses::{OutputContent, OutputItem, ResponsesAPIResponse}; let chat_response = ChatCompletionsResponse { - id: "chatcmpl-123".to_string(), + id: "resp_6de5512800cf4375a329a473a4f02879".to_string(), object: Some("chat.completion".to_string()), created: 1677652288, model: "gpt-4".to_string(), @@ -974,7 +990,9 @@ mod tests { let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap(); - assert_eq!(responses_api.id, "chatcmpl-123"); + // Response ID should be generated with resp_ prefix + assert!(responses_api.id.starts_with("resp_"), "Response ID should start with 'resp_'"); + assert_eq!(responses_api.id.len(), 37, "Response ID should be resp_ + 32 char UUID"); assert_eq!(responses_api.object, "response"); assert_eq!(responses_api.model, "gpt-4"); diff --git a/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs index 9e2f083e..30b40956 100644 --- a/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs +++ b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs @@ -58,11 +58,11 @@ impl TryFrom for ChatCompletionsStreamResponse { None, )), - MessagesStreamEvent::ContentBlockStart { content_block, .. } => { - convert_content_block_start(content_block) + MessagesStreamEvent::ContentBlockStart { content_block, index } => { + convert_content_block_start(content_block, index) } - MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta), + MessagesStreamEvent::ContentBlockDelta { delta, index } => convert_content_delta(delta, index), MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()), @@ -272,6 +272,7 @@ impl TryFrom for ChatCompletionsStreamResponse { /// Convert content block start to OpenAI chunk fn convert_content_block_start( content_block: MessagesContentBlock, + index: u32, ) -> Result { match content_block { MessagesContentBlock::Text { .. 
@@ -291,7 +292,7 @@ fn convert_content_block_start(
                 refusal: None,
                 function_call: None,
                 tool_calls: Some(vec![ToolCallDelta {
-                    index: 0,
+                    index,
                     id: Some(id),
                     call_type: Some("function".to_string()),
                     function: Some(FunctionCallDelta {
@@ -313,6 +314,7 @@ fn convert_content_block_start(
 /// Convert content delta to OpenAI chunk
 fn convert_content_delta(
     delta: MessagesContentDelta,
+    index: u32,
 ) -> Result {
     match delta {
         MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
@@ -350,7 +352,7 @@ fn convert_content_delta(
                 refusal: None,
                 function_call: None,
                 tool_calls: Some(vec![ToolCallDelta {
-                    index: 0,
+                    index,
                     id: None,
                     call_type: None,
                     function: Some(FunctionCallDelta {
diff --git a/docs/db_setup/README.md b/docs/db_setup/README.md
new file mode 100644
index 00000000..34aff973
--- /dev/null
+++ b/docs/db_setup/README.md
@@ -0,0 +1,109 @@
+# Database Setup for Conversation State Storage
+
+This directory contains SQL scripts needed to set up database tables for storing conversation state when using the OpenAI Responses API.
+
+## Prerequisites
+
+- PostgreSQL database (Supabase or self-hosted)
+- Database connection credentials
+- `psql` CLI tool or database admin access
+
+## Setup Instructions
+
+### Option 1: Using psql
+
+```bash
+psql $DATABASE_URL -f docs/db_setup/conversation_states.sql
+```
+
+### Option 2: Using Supabase Dashboard
+
+1. Log in to your Supabase project dashboard
+2. Navigate to the SQL Editor
+3. Copy and paste the contents of `conversation_states.sql`
+4. Run the query
+
+### Option 3: Direct Database Connection
+
+Connect to your PostgreSQL database using your preferred client and execute the SQL from `conversation_states.sql`.
+
+## Verification
+
+After running the setup, verify the table was created:
+
+```sql
+SELECT tablename FROM pg_tables WHERE tablename = 'conversation_states';
+```
+
+You should see `conversation_states` in the results.
+
+## Configuration
+
+After setting up the database table, enable the postgres backend in your arch config: add a `state_storage` block with `type: postgres` and a `connection_string` for your database. Environment variables referenced in the connection string (e.g., `$DB_PASSWORD`) are substituted when the config is rendered.
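+
+For example, a minimal block (mirroring the sample config in
+`docs/source/resources/includes/arch_config_state_storage_example.yaml`; the
+user, project, and host below are placeholders):
+
+```yaml
+state_storage:
+  # Type: memory | postgres
+  type: postgres
+  # $DB_PASSWORD is substituted from the environment when the config is rendered
+  connection_string: "postgresql://postgres.myproject:$DB_PASSWORD@aws-0-us-west-2.pooler.supabase.com:5432/postgres"
+```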
+### Supabase Connection String
+
+**Important:** Supabase requires different connection strings depending on your network:
+
+- **IPv4 Networks (Most Common)**: Use the **Session Pooler** connection string (port 5432):
+  ```
+  postgresql://postgres.[PROJECT-REF]:[PASSWORD]@aws-0-[REGION].pooler.supabase.com:5432/postgres
+  ```
+
+- **IPv6 Networks**: Use the direct connection (port 5432):
+  ```
+  postgresql://postgres:[PASSWORD]@db.[PROJECT-REF].supabase.co:5432/postgres
+  ```
+
+**How to get your connection string:**
+1. Go to your Supabase project dashboard
+2. Settings → Database → Connection Pooling
+3. Copy the **Session mode** connection string
+4. Replace `[YOUR-PASSWORD]` with your actual database password
+5. URL-encode special characters in the password (e.g., `#` becomes `%23`)
+
+**Example:**
+```bash
+# If your password is "MyPass#123", encode it as "MyPass%23123"
+export DATABASE_URL="postgresql://postgres.myproject:MyPass%23123@aws-0-us-west-2.pooler.supabase.com:5432/postgres"
+```
+
+### Testing the Connection
+
+To test that your connection string works:
+```bash
+export TEST_DATABASE_URL="your-connection-string-here"
+cd crates/brightstaff
+cargo test supabase -- --nocapture
+```
+
+## Table Schema
+
+The `conversation_states` table stores:
+- `response_id` (TEXT, PRIMARY KEY): Unique identifier for each conversation
+- `input_items` (JSONB): Array of conversation messages and context
+- `created_at` (BIGINT): Unix timestamp when conversation started
+- `model` (TEXT): Model name used for the conversation
+- `provider` (TEXT): LLM provider name
+- `updated_at` (TIMESTAMP): Last update time (auto-managed)
+
+## Maintenance
+
+### Cleanup Old Conversations
+
+To prevent unbounded growth, consider periodically cleaning up old conversation states:
+
+```sql
+-- Delete conversations older than 7 days
+DELETE FROM conversation_states
+WHERE updated_at < NOW() - INTERVAL '7 days';
+```
+
+You can automate this with a cron job or database trigger.
+
+## Troubleshooting
+
+If you encounter errors on first use:
+- **"Table 'conversation_states' does not exist"**: Run the setup SQL
+- **Connection errors**: Verify your DATABASE_URL is correct
+- **Permission errors**: Ensure your database user has CREATE TABLE privileges
diff --git a/docs/db_setup/conversation_states.sql b/docs/db_setup/conversation_states.sql
new file mode 100644
index 00000000..26272423
--- /dev/null
+++ b/docs/db_setup/conversation_states.sql
@@ -0,0 +1,31 @@
+-- Conversation State Storage Table
+-- This table stores conversational context for the OpenAI Responses API
+-- Run this SQL against your PostgreSQL/Supabase database before enabling conversation state storage
+
+CREATE TABLE IF NOT EXISTS conversation_states (
+    response_id TEXT PRIMARY KEY,
+    input_items JSONB NOT NULL,
+    created_at BIGINT NOT NULL,
+    model TEXT NOT NULL,
+    provider TEXT NOT NULL,
+    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+);
+
+-- Indexes for common query patterns
+CREATE INDEX IF NOT EXISTS idx_conversation_states_created_at
+    ON conversation_states(created_at);
+
+CREATE INDEX IF NOT EXISTS idx_conversation_states_provider
+    ON conversation_states(provider);
+
+-- Optional: index to support automatic cleanup of old conversations
+-- Uncomment and adjust the retention period as needed
+-- CREATE INDEX IF NOT EXISTS idx_conversation_states_updated_at
+--     ON conversation_states(updated_at);
+
+COMMENT ON TABLE conversation_states IS 'Stores conversation history for OpenAI Responses API continuity';
+COMMENT ON COLUMN conversation_states.response_id IS 'Unique identifier for the conversation state';
+COMMENT ON COLUMN conversation_states.input_items IS 'JSONB array of conversation messages and context';
+COMMENT ON COLUMN conversation_states.created_at IS 'Unix timestamp (seconds) when the conversation started';
+COMMENT ON COLUMN conversation_states.model IS 'Model name used for this conversation';
+COMMENT ON COLUMN conversation_states.provider IS 'LLM provider (e.g., openai, anthropic, bedrock)';
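+
+-- A hedged example of the periodic cleanup described in docs/db_setup/README.md
+-- (illustrative only; tune the interval to your retention policy):
+-- DELETE FROM conversation_states
+-- WHERE updated_at < NOW() - INTERVAL '7 days';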
diff --git a/docs/source/resources/includes/arch_config_state_storage_example.yaml b/docs/source/resources/includes/arch_config_state_storage_example.yaml
new file mode 100644
index 00000000..27a417c0
--- /dev/null
+++ b/docs/source/resources/includes/arch_config_state_storage_example.yaml
@@ -0,0 +1,32 @@
+version: v0.1
+
+listeners:
+  egress_traffic:
+    address: 0.0.0.0
+    port: 12000
+    message_format: openai
+    timeout: 30s
+
+llm_providers:
+
+  # OpenAI Models
+  - model: openai/gpt-5-mini-2025-08-07
+    access_key: $OPENAI_API_KEY
+    default: true
+
+  # Anthropic Models
+  - model: anthropic/claude-sonnet-4-20250514
+    access_key: $ANTHROPIC_API_KEY
+
+# State storage configuration for v1/responses API
+# Manages conversation state for multi-turn conversations
+state_storage:
+  # Type: memory | postgres
+  type: postgres
+
+  # Connection string for postgres type
+  # Environment variables are supported using $VAR_NAME or ${VAR_NAME} syntax
+  # Replace [USER] and [HOST] with your actual database credentials
+  # Variables like $DB_PASSWORD MUST be set before running config validation/rendering
+  # Example: Replace [USER] with 'myuser' and [HOST] with 'db.example.com:5432'
+  connection_string: "postgresql://[USER]:$DB_PASSWORD@[HOST]:5432/postgres"
diff --git a/tests/e2e/arch_config_memory_state_v1_responses.yaml b/tests/e2e/arch_config_memory_state_v1_responses.yaml
new file mode 100644
index 00000000..afc40910
--- /dev/null
+++ b/tests/e2e/arch_config_memory_state_v1_responses.yaml
@@ -0,0 +1,25 @@
+version: v0.1
+
+listeners:
+  egress_traffic:
+    address: 0.0.0.0
+    port: 12000
+    message_format: openai
+    timeout: 30s
+
+llm_providers:
+
+  # OpenAI Models
+  - model: openai/gpt-5-mini-2025-08-07
+    access_key: $OPENAI_API_KEY
+    default: true
+
+  # Anthropic Models
+  - model: anthropic/claude-sonnet-4-20250514
+    access_key: $ANTHROPIC_API_KEY
+
+# State storage configuration for v1/responses API
+# Manages conversation state for multi-turn conversations
+state_storage:
+  # Type: memory | postgres
+  type: memory
diff --git a/tests/e2e/run_e2e_tests.sh b/tests/e2e/run_e2e_tests.sh
index f60f79bc..a6e66121 100644
--- a/tests/e2e/run_e2e_tests.sh
+++ b/tests/e2e/run_e2e_tests.sh
@@ -69,6 +69,14 @@ log running e2e tests for openai responses api client
 log ========================================
 poetry run pytest test_openai_responses_api_client.py
 
+log startup arch gateway with state storage for openai responses api client demo
+archgw down
+archgw up arch_config_memory_state_v1_responses.yaml
+
+log running e2e tests for openai responses api client with state
+log ========================================
+poetry run pytest test_openai_responses_api_client_with_state.py
+
 log shutting down the weather_forecast demo
 log =======================================
 cd ../../demos/samples_python/weather_forecast
diff --git a/tests/e2e/test_openai_responses_api_client_with_state.py b/tests/e2e/test_openai_responses_api_client_with_state.py
new file mode 100644
index 00000000..c23307e6
--- /dev/null
+++ b/tests/e2e/test_openai_responses_api_client_with_state.py
@@ -0,0 +1,218 @@
+import openai
+import pytest
+import os
+import logging
+import sys
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    handlers=[logging.StreamHandler(sys.stdout)],
+)
+logger = logging.getLogger(__name__)
+
+LLM_GATEWAY_ENDPOINT = os.getenv(
+    "LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1/chat/completions"
+)
+
+
+def test_conversation_state_management_two_turn():
+    """
+    Test conversation state management across two turns:
+    1. Send initial message to non-OpenAI model via v1/responses
+    2. Capture response_id from first response
+    3. Send second message with previous_response_id
+    4. Verify model receives both messages in correct order
+    """
+    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
+    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
+
+    logger.info("\n" + "=" * 80)
+    logger.info("TEST: Conversation State Management - Two Turn Flow")
+    logger.info("=" * 80)
+
+    # Turn 1: Send initial message to Anthropic (non-OpenAI model)
+    logger.info("\n[TURN 1] Sending initial message...")
+    resp1 = client.responses.create(
+        model="claude-sonnet-4-20250514",
+        input="My name is Alice and I like pizza.",
+    )
+
+    # Extract response_id from first response
+    response_id_1 = resp1.id
+    logger.info(f"[TURN 1] Received response_id: {response_id_1}")
+    logger.info(f"[TURN 1] Model response: {resp1.output_text}")
+
+    assert response_id_1 is not None, "First response should have an id"
+    assert len(resp1.output_text) > 0, "First response should have content"
+
+    # Turn 2: Send follow-up message with previous_response_id
+    # Ask the model to list all messages to verify state was combined
+    logger.info(
+        f"\n[TURN 2] Sending follow-up with previous_response_id={response_id_1}"
+    )
+    resp2 = client.responses.create(
+        model="claude-sonnet-4-20250514",
+        input="Please list all the messages you have received in our conversation, numbering each one.",
+        previous_response_id=response_id_1,
+    )
+
+    response_id_2 = resp2.id
+    logger.info(f"[TURN 2] Received response_id: {response_id_2}")
+    logger.info(f"[TURN 2] Model response: {resp2.output_text}")
+
+    assert response_id_2 is not None, "Second response should have an id"
+    assert response_id_2 != response_id_1, "Second response should have different id"
+
+    # Verify the model received the conversation history
+    # The response should reference both the initial message and the follow-up
+    response_lower = resp2.output_text.lower()
+
+    # Check if the model acknowledges receiving multiple messages
+    # Different models might format this differently, so we check for various indicators
+    has_conversation_context = (
+        "alice" in response_lower  # References the name from turn 1
+        or "pizza" in response_lower  # References the preference from turn 1
+        or "two" in response_lower  # Mentions the number of messages
+        or "2" in response_lower  # Numeric indicator of the message count
+        or "first" in response_lower  # References the first message
+        or "second" in response_lower  # References the second message
+    )
+
+    logger.info(
+        f"\n[VALIDATION] Conversation context preserved: {has_conversation_context}"
+    )
+
+    print(f"\n{'='*80}")
+    print("Conversation State Test Results:")
+    print(f"Turn 1 Response ID: {response_id_1}")
+    print(f"Turn 2 Response ID: {response_id_2}")
+    print(f"Turn 1 Output: {resp1.output_text[:100]}...")
+    print(f"Turn 2 Output: {resp2.output_text}")
+    print(f"Conversation Context Preserved: {has_conversation_context}")
+    print(f"{'='*80}\n")
+
+    assert has_conversation_context, (
+        f"Model should have received conversation history. "
+        f"Response: {resp2.output_text}"
+    )
+
+
+def test_conversation_state_management_two_turn_streaming():
+    """
+    Test conversation state management across two turns with streaming:
+    1. Send initial streaming message to non-OpenAI model via v1/responses
+    2. Capture response_id from first response
+    3. Send second streaming message with previous_response_id
+    4. Verify model receives both messages in correct order
+    """
+    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
+    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
+
+    logger.info("\n" + "=" * 80)
+    logger.info("TEST: Conversation State Management - Two Turn Streaming Flow")
+    logger.info("=" * 80)
+
+    # Turn 1: Send initial streaming message to Anthropic (non-OpenAI model)
+    logger.info("\n[TURN 1] Sending initial streaming message...")
+    stream1 = client.responses.create(
+        model="claude-sonnet-4-20250514",
+        input="My name is Alice and I like pizza.",
+        stream=True,
+    )
+
+    # Collect streamed content and capture response_id
+    text_chunks_1 = []
+    response_id_1 = None
+
+    for event in stream1:
+        if getattr(event, "type", None) == "response.output_text.delta" and getattr(
+            event, "delta", None
+        ):
+            text_chunks_1.append(event.delta)
+
+        # Capture response_id from response.completed event
+        if getattr(event, "type", None) == "response.completed" and getattr(
+            event, "response", None
+        ):
+            response_id_1 = event.response.id
+
+    output_1 = "".join(text_chunks_1)
+    logger.info(f"[TURN 1] Received response_id: {response_id_1}")
+    logger.info(f"[TURN 1] Model response: {output_1}")
+
+    assert response_id_1 is not None, "First response should have an id"
+    assert len(output_1) > 0, "First response should have content"
+
+    # Turn 2: Send follow-up streaming message with previous_response_id
+    logger.info(
+        f"\n[TURN 2] Sending follow-up streaming request with previous_response_id={response_id_1}"
+    )
+    stream2 = client.responses.create(
+        model="claude-sonnet-4-20250514",
+        input="Please list all the messages you have received in our conversation, numbering each one.",
+        previous_response_id=response_id_1,
+        stream=True,
+    )
+
+    # Collect streamed content from second response
+    text_chunks_2 = []
+    response_id_2 = None
+
+    for event in stream2:
+        if getattr(event, "type", None) == "response.output_text.delta" and getattr(
+            event, "delta", None
+        ):
+            text_chunks_2.append(event.delta)
+
+        # Capture response_id from response.completed event
+        if getattr(event, "type", None) == "response.completed" and getattr(
+            event, "response", None
+        ):
+            response_id_2 = event.response.id
+
+    output_2 = "".join(text_chunks_2)
+    logger.info(f"[TURN 2] Received response_id: {response_id_2}")
+    logger.info(f"[TURN 2] Model response: {output_2}")
+
+    assert response_id_2 is not None, "Second response should have an id"
+    assert response_id_2 != response_id_1, "Second response should have different id"
+
+    # Verify the model received the conversation history
+    response_lower = output_2.lower()
+
+    # Check if the model acknowledges receiving multiple messages
+    has_conversation_context = (
+        "alice" in response_lower  # References the name from turn 1
+        or "pizza" in response_lower  # References the preference from turn 1
+        or "two" in response_lower  # Mentions the number of messages
+        or "2" in response_lower  # Numeric indicator of the message count
+        or "first" in response_lower  # References the first message
+        or "second" in response_lower  # References the second message
+    )
+
+    logger.info(
+        f"\n[VALIDATION] Conversation context preserved: {has_conversation_context}"
+    )
+
+    print(f"\n{'='*80}")
+    print("Streaming Conversation State Test Results:")
+    print(f"Turn 1 Response ID: {response_id_1}")
+    print(f"Turn 2 Response ID: {response_id_2}")
+    print(f"Turn 1 Output: {output_1[:100]}...")
+    print(f"Turn 2 Output: {output_2}")
+    print(f"Conversation Context Preserved: {has_conversation_context}")
+    print(f"{'='*80}\n")
+
+    assert has_conversation_context, (
+        f"Model should have received conversation history. " f"Response: {output_2}"
+    )