first commit with tests to enable state mamangement via memory

This commit is contained in:
Salman Paracha 2025-12-14 22:21:00 -08:00
parent a79f55f313
commit bce917c9d4
16 changed files with 1951 additions and 66 deletions

19
crates/Cargo.lock generated
View file

@ -308,11 +308,13 @@ name = "brightstaff"
version = "0.1.0"
dependencies = [
"async-openai",
"async-trait",
"bytes",
"chrono",
"common",
"eventsource-client",
"eventsource-stream",
"flate2",
"futures",
"futures-util",
"hermesllm",
@ -707,6 +709,16 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "flate2"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "fnv"
version = "1.0.7"
@ -1533,6 +1545,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
@ -2650,6 +2663,12 @@ dependencies = [
"libc",
]
[[package]]
name = "simd-adler32"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2"
[[package]]
name = "similar"
version = "2.7.0"

View file

@ -5,11 +5,13 @@ edition = "2021"
[dependencies]
async-openai = "0.30.1"
async-trait = "0.1"
bytes = "1.10.1"
chrono = "0.4"
common = { version = "0.1.0", path = "../common", features = ["trace-collection"] }
eventsource-client = "0.15.0"
eventsource-stream = "0.2.3"
flate2 = "1.0"
futures = "0.3.31"
futures-util = "0.3.31"
hermesllm = { version = "0.1.0", path = "../hermesllm" }

View file

@ -1,8 +1,9 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
use common::traces::TraceCollector;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
@ -16,6 +17,11 @@ use tracing::{debug, warn};
use crate::router::llm_router::RouterService;
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
StateStorage, StateStorageError,
extract_input_items, retrieve_and_combine_input
};
use crate::tracing::operation_component;
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
@ -31,14 +37,20 @@ pub async fn llm_chat(
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<Vec<LlmProvider>>>,
trace_collector: Arc<TraceCollector>,
state_storage: Arc<dyn StateStorage>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
let request_id = request_headers
.get(REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| "unknown".to_string());
// Extract or generate traceparent - this establishes the trace context for all spans
let traceparent: String = request_headers
.get("traceparent")
.get(TRACE_PARENT_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| {
@ -51,7 +63,8 @@ pub async fn llm_chat(
let chat_request_bytes = request.collect().await?.to_bytes();
debug!(
"Received request body (raw utf8): {}",
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | REQUEST BODY (raw utf8): {}",
request_id,
String::from_utf8_lossy(&chat_request_bytes)
);
@ -61,14 +74,19 @@ pub async fn llm_chat(
)) {
Ok(request) => request,
Err(err) => {
warn!("Failed to parse request as ProviderRequestType: {}", err);
let err_msg = format!("Failed to parse request: {}", err);
warn!("[PLANO_REQ_ID:{}] | BRIGHTSTAFF | Failed to parse request as ProviderRequestType: {}", request_id, err);
let err_msg = format!("[PLANO_REQ_ID:{}] | BRIGHTSTAFF | Failed to parse request: {}", request_id, err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
// === v1/responses state management: Extract input items early ===
let mut original_input_items = Vec::new();
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)));
// Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model
let model_from_request = client_request.model().to_string();
@ -83,9 +101,76 @@ pub async fn llm_chat(
client_request.set_model(resolved_model.clone());
if client_request.remove_metadata_key("archgw_preference_config") {
debug!("Removed archgw_preference_config from metadata");
debug!("[PLANO (BRIGHTSTAFF)] Removed archgw_preference_config from metadata");
}
// === v1/responses state management: Determine upstream API and combine input if needed ===
// Do this BEFORE routing since routing consumes the request
let mut should_manage_state = false;
if is_responses_api_client {
if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
// Extract original input once
original_input_items = extract_input_items(&responses_req.input);
// Get the upstream path and check if it's ResponsesAPI
let upstream_path = get_upstream_path(
&llm_providers,
&resolved_model,
&request_path,
&resolved_model,
is_streaming_request,
).await;
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
if should_manage_state {
// Retrieve and combine conversation history if previous_response_id exists
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
match retrieve_and_combine_input(
state_storage.clone(),
prev_resp_id,
original_input_items, // Pass ownership instead of cloning
)
.await
{
Ok(combined_input) => {
// Update both the request and original_input_items
responses_req.input = InputParam::Items(combined_input.clone());
original_input_items = combined_input;
debug!("[PLANO (BRIGHTSTAFF)] Updated request with conversation history ({} items)", original_input_items.len());
}
Err(StateStorageError::NotFound(_)) => {
// Return 409 Conflict when previous_response_id not found
warn!("[PLANO (BRIGHTSTAFF)] Previous response_id not found: {}", prev_resp_id);
let err_msg = format!(
"[PLANO (BRIGHTSTAFF)] Conversation state not found for previous_response_id: {}",
prev_resp_id
);
let mut conflict_response = Response::new(full(err_msg));
*conflict_response.status_mut() = StatusCode::CONFLICT;
return Ok(conflict_response);
}
Err(e) => {
// Log warning but continue on other storage errors
warn!(
"Failed to retrieve conversation state for {}: {}",
prev_resp_id, e
);
// Restore original_input_items since we passed ownership
original_input_items = extract_input_items(&responses_req.input);
}
}
}
} else {
debug!("[PLANO (BRIGHTSTAFF)] Upstream supports ResponsesAPI natively, passing through without state management");
}
}
}
// Serialize request for upstream BEFORE router consumes it
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
// Determine routing using the dedicated router_chat module
@ -110,7 +195,7 @@ pub async fn llm_chat(
let model_name = routing_result.model_name;
debug!(
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
"[PLANO ARCH_ROUTER] URL: {}, Resolved Model: {}",
full_qualified_llm_provider_url, model_name
);
@ -173,15 +258,40 @@ pub async fn llm_chat(
&llm_providers,
).await;
// Use PassthroughProcessor to track streaming metrics and finalize the span
let processor = ObservableStreamProcessor::new(
// Create base processor for metrics and tracing
let base_processor = ObservableStreamProcessor::new(
trace_collector,
operation_component::LLM,
llm_span,
request_start_time,
);
let streaming_response = create_streaming_response(byte_stream, processor, 16);
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI)
let streaming_response = if should_manage_state && !original_input_items.is_empty() {
// Extract Content-Encoding header to handle decompression for state parsing
let content_encoding = response_headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
// Wrap with state management processor to store state after response completes
let state_processor = ResponsesStateProcessor::new(
base_processor,
state_storage,
original_input_items,
resolved_model.clone(),
model_name.clone(),
is_streaming_request,
false, // Not OpenAI upstream since should_manage_state is true
content_encoding,
request_id.clone(),
);
create_streaming_response(byte_stream, state_processor, 16)
} else {
// Use base processor without state management
create_streaming_response(byte_stream, base_processor, 16)
};
match response.body(streaming_response.body) {
Ok(response) => Ok(response),
@ -301,35 +411,7 @@ async fn get_upstream_path(
resolved_model: &str,
is_streaming: bool,
) -> String {
let providers_lock = llm_providers.read().await;
// First, try to find by model name or provider name
let provider = providers_lock.iter().find(|p| {
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|| p.name == model_name
});
let (provider_id, base_url_path_prefix) = if let Some(provider) = provider {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
(provider_id, prefix)
} else {
let default_provider = providers_lock.iter().find(|p| {
p.default.unwrap_or(false)
});
if let Some(provider) = default_provider {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
(provider_id, prefix)
} else {
// Last resort: use OpenAI as hardcoded fallback
warn!("No default provider found, falling back to OpenAI");
(hermesllm::ProviderId::OpenAI, None)
}
};
drop(providers_lock);
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
// Calculate the upstream path using the proper API
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
@ -343,3 +425,37 @@ async fn get_upstream_path(
base_url_path_prefix.as_deref(),
)
}
/// Helper function to get provider info (ProviderId and base_url_path_prefix)
async fn get_provider_info(
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
model_name: &str,
) -> (hermesllm::ProviderId, Option<String>) {
let providers_lock = llm_providers.read().await;
// First, try to find by model name or provider name
let provider = providers_lock.iter().find(|p| {
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|| p.name == model_name
});
if let Some(provider) = provider {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
return (provider_id, prefix);
}
let default_provider = providers_lock.iter().find(|p| {
p.default.unwrap_or(false)
});
if let Some(provider) = default_provider {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
(provider_id, prefix)
} else {
// Last resort: use OpenAI as hardcoded fallback
warn!("No default provider found, falling back to OpenAI");
(hermesllm::ProviderId::OpenAI, None)
}
}

View file

@ -1,4 +1,5 @@
pub mod handlers;
pub mod router;
pub mod state;
pub mod tracing;
pub mod utils;

View file

@ -3,6 +3,8 @@ use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::function_calling::{function_calling_chat_handler};
use brightstaff::router::llm_router::RouterService;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::Configuration;
@ -101,6 +103,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let trace_collector = Arc::new(TraceCollector::new(tracing_enabled));
let _flusher_handle = trace_collector.clone().start_background_flusher();
// Initialize conversation state storage for v1/responses
// TODO: Make this configurable (MEMORY vs SUPABASE) via arch_config.yaml
let state_storage: Arc<dyn StateStorage> = Arc::new(MemoryConversationalStorage::new());
info!("Initialized conversation state storage: Memory");
loop {
let (stream, _) = listener.accept().await?;
@ -115,6 +122,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let agents_list = agents_list.clone();
let listeners = listeners.clone();
let trace_collector = trace_collector.clone();
let state_storage = state_storage.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
@ -124,13 +132,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let agents_list = agents_list.clone();
let listeners = listeners.clone();
let trace_collector = trace_collector.clone();
let state_storage = state_storage.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector)
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
.with_context(parent_cx)
.await
}

View file

@ -0,0 +1,584 @@
use super::{OpenAIConversationState, StateStorage, StateStorageError};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// In-memory storage backend for conversation state
/// Uses a HashMap wrapped in Arc<RwLock<>> for thread-safe access
#[derive(Clone)]
pub struct MemoryConversationalStorage {
storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
}
impl MemoryConversationalStorage {
pub fn new() -> Self {
Self {
storage: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl Default for MemoryConversationalStorage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl StateStorage for MemoryConversationalStorage {
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
let response_id = state.response_id.clone();
let mut storage = self.storage.write().await;
debug!(
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Storing conversation state: model={}, provider={}, input_items={}",
response_id, state.model, state.provider, state.input_items.len()
);
storage.insert(response_id, state);
Ok(())
}
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
let storage = self.storage.read().await;
match storage.get(response_id) {
Some(state) => {
debug!(
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Retrieved conversation state: input_items={}",
response_id, state.input_items.len()
);
Ok(state.clone())
}
None => {
warn!(
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Conversation state not found",
response_id
);
Err(StateStorageError::NotFound(response_id.to_string()))
}
}
}
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
let storage = self.storage.read().await;
Ok(storage.contains_key(response_id))
}
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
let mut storage = self.storage.write().await;
if storage.remove(response_id).is_some() {
debug!(
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Deleted conversation state",
response_id
);
Ok(())
} else {
Err(StateStorageError::NotFound(response_id.to_string()))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent};
fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
let mut input_items = Vec::new();
for i in 0..num_messages {
input_items.push(InputItem::Message(InputMessage {
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
content: vec![InputContent::InputText {
text: format!("Message {}", i),
}],
}));
}
OpenAIConversationState {
response_id: response_id.to_string(),
input_items,
created_at: 1234567890,
model: "claude-3".to_string(),
provider: "anthropic".to_string(),
}
}
#[tokio::test]
async fn test_put_and_get_success() {
let storage = MemoryConversationalStorage::new();
let state: OpenAIConversationState = create_test_state("resp_001", 3);
// Store
storage.put(state.clone()).await.unwrap();
// Retrieve
let retrieved = storage.get("resp_001").await.unwrap();
assert_eq!(retrieved.response_id, state.response_id);
assert_eq!(retrieved.model, state.model);
assert_eq!(retrieved.provider, state.provider);
assert_eq!(retrieved.input_items.len(), 3);
assert_eq!(retrieved.created_at, state.created_at);
}
#[tokio::test]
async fn test_put_overwrites_existing() {
let storage = MemoryConversationalStorage::new();
// First state
let state1 = create_test_state("resp_002", 2);
storage.put(state1).await.unwrap();
// Overwrite with new state
let state2 = OpenAIConversationState {
response_id: "resp_002".to_string(),
input_items: vec![],
created_at: 9999999999,
model: "gpt-4".to_string(),
provider: "openai".to_string(),
};
storage.put(state2.clone()).await.unwrap();
// Should retrieve the new state
let retrieved = storage.get("resp_002").await.unwrap();
assert_eq!(retrieved.model, "gpt-4");
assert_eq!(retrieved.provider, "openai");
assert_eq!(retrieved.input_items.len(), 0);
assert_eq!(retrieved.created_at, 9999999999);
}
#[tokio::test]
async fn test_get_not_found() {
let storage = MemoryConversationalStorage::new();
let result = storage.get("nonexistent").await;
assert!(result.is_err());
match result.unwrap_err() {
StateStorageError::NotFound(id) => {
assert_eq!(id, "nonexistent");
}
_ => panic!("Expected NotFound error"),
}
}
#[tokio::test]
async fn test_exists_returns_false_for_nonexistent() {
let storage = MemoryConversationalStorage::new();
assert!(!storage.exists("resp_003").await.unwrap());
}
#[tokio::test]
async fn test_exists_returns_true_after_put() {
let storage = MemoryConversationalStorage::new();
let state = create_test_state("resp_004", 1);
assert!(!storage.exists("resp_004").await.unwrap());
storage.put(state).await.unwrap();
assert!(storage.exists("resp_004").await.unwrap());
}
#[tokio::test]
async fn test_delete_success() {
let storage = MemoryConversationalStorage::new();
let state = create_test_state("resp_005", 2);
storage.put(state).await.unwrap();
assert!(storage.exists("resp_005").await.unwrap());
// Delete
storage.delete("resp_005").await.unwrap();
// Should no longer exist
assert!(!storage.exists("resp_005").await.unwrap());
assert!(storage.get("resp_005").await.is_err());
}
#[tokio::test]
async fn test_delete_not_found() {
let storage = MemoryConversationalStorage::new();
let result = storage.delete("nonexistent").await;
assert!(result.is_err());
match result.unwrap_err() {
StateStorageError::NotFound(id) => {
assert_eq!(id, "nonexistent");
}
_ => panic!("Expected NotFound error"),
}
}
#[tokio::test]
async fn test_merge_combines_inputs() {
let storage = MemoryConversationalStorage::new();
// Create a previous state with 2 messages
let prev_state = create_test_state("resp_006", 2);
// Create current input with 1 message
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "New message".to_string(),
}],
})];
// Merge
let merged = storage.merge(&prev_state, current_input);
// Should have 3 messages total (2 from prev + 1 current)
assert_eq!(merged.len(), 3);
}
#[tokio::test]
async fn test_merge_preserves_order() {
let storage = MemoryConversationalStorage::new();
// Previous state has messages 0 and 1
let prev_state = create_test_state("resp_007", 2);
// Current input has message 2
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Message 2".to_string(),
}],
})];
let merged = storage.merge(&prev_state, current_input);
// Verify order: prev messages first, then current
let InputItem::Message(msg) = &merged[0];
match &msg.content[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 0"),
_ => panic!("Expected InputText"),
}
let InputItem::Message(msg) = &merged[2];
match &msg.content[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 2"),
_ => panic!("Expected InputText"),
}
}
#[tokio::test]
async fn test_merge_with_empty_current_input() {
let storage = MemoryConversationalStorage::new();
let prev_state = create_test_state("resp_008", 3);
let merged = storage.merge(&prev_state, vec![]);
// Should just have the previous state's items
assert_eq!(merged.len(), 3);
}
#[tokio::test]
async fn test_merge_with_empty_previous_state() {
let storage = MemoryConversationalStorage::new();
let prev_state = OpenAIConversationState {
response_id: "resp_009".to_string(),
input_items: vec![],
created_at: 1234567890,
model: "gpt-4".to_string(),
provider: "openai".to_string(),
};
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Only message".to_string(),
}],
})];
let merged = storage.merge(&prev_state, current_input);
// Should just have the current input
assert_eq!(merged.len(), 1);
}
#[tokio::test]
async fn test_concurrent_access() {
let storage = MemoryConversationalStorage::new();
// Spawn multiple tasks that write concurrently
let mut handles = vec![];
for i in 0..10 {
let storage_clone = storage.clone();
let handle = tokio::spawn(async move {
let state = create_test_state(&format!("resp_{}", i), i % 3);
storage_clone.put(state).await.unwrap();
});
handles.push(handle);
}
// Wait for all tasks
for handle in handles {
handle.await.unwrap();
}
// Verify all states were stored
for i in 0..10 {
assert!(storage.exists(&format!("resp_{}", i)).await.unwrap());
}
}
#[tokio::test]
async fn test_multiple_operations_on_same_id() {
let storage = MemoryConversationalStorage::new();
let state = create_test_state("resp_010", 1);
// Put
storage.put(state.clone()).await.unwrap();
// Get
let retrieved = storage.get("resp_010").await.unwrap();
assert_eq!(retrieved.response_id, "resp_010");
// Exists
assert!(storage.exists("resp_010").await.unwrap());
// Put again (overwrite)
let new_state = create_test_state("resp_010", 5);
storage.put(new_state).await.unwrap();
// Get updated
let updated = storage.get("resp_010").await.unwrap();
assert_eq!(updated.input_items.len(), 5);
// Delete
storage.delete("resp_010").await.unwrap();
// Should not exist
assert!(!storage.exists("resp_010").await.unwrap());
}
#[tokio::test]
async fn test_merge_with_tool_call_flow() {
// This test simulates a realistic tool call conversation flow:
// 1. User sends message: "What's the weather?"
// 2. Model responds with function call (converted to assistant message)
// 3. User sends function call output in next request with previous_response_id
// The merge should combine: user message + assistant function call + function output
let storage = MemoryConversationalStorage::new();
// Step 1: Previous state contains the initial exchange
// - User message: "What's the weather in SF?"
// - Assistant message (converted from FunctionCall): "Called function: get_weather..."
let prev_state = OpenAIConversationState {
response_id: "resp_tool_001".to_string(),
input_items: vec![
// Original user message
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "What's the weather in San Francisco?".to_string(),
}],
}),
// Assistant's function call (converted from OutputItem::FunctionCall)
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(),
}],
}),
],
created_at: 1234567890,
model: "claude-3".to_string(),
provider: "anthropic".to_string(),
};
// Step 2: Current request includes function call output
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
}],
})];
// Step 3: Merge should combine all conversation history
let merged = storage.merge(&prev_state, current_input);
// Should have 3 items: user question + assistant function call + function output
assert_eq!(merged.len(), 3);
// Verify the order and content
let InputItem::Message(msg1) = &merged[0];
assert!(matches!(msg1.role, MessageRole::User));
match &msg1.content[0] {
InputContent::InputText { text } => {
assert!(text.contains("weather in San Francisco"));
}
_ => panic!("Expected InputText"),
}
let InputItem::Message(msg2) = &merged[1];
assert!(matches!(msg2.role, MessageRole::Assistant));
match &msg2.content[0] {
InputContent::InputText { text } => {
assert!(text.contains("get_weather"));
}
_ => panic!("Expected InputText"),
}
let InputItem::Message(msg3) = &merged[2];
assert!(matches!(msg3.role, MessageRole::User));
match &msg3.content[0] {
InputContent::InputText { text } => {
assert!(text.contains("Function result"));
assert!(text.contains("temperature"));
}
_ => panic!("Expected InputText"),
}
}
#[tokio::test]
async fn test_merge_with_multiple_tool_calls() {
// Test a more complex scenario with multiple tool calls
let storage = MemoryConversationalStorage::new();
// Previous state has: user message + 2 function calls from assistant
let prev_state = OpenAIConversationState {
response_id: "resp_tool_002".to_string(),
input_items: vec![
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "What's the weather and time in SF?".to_string(),
}],
}),
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(),
}],
}),
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(),
}],
}),
],
created_at: 1234567890,
model: "gpt-4".to_string(),
provider: "openai".to_string(),
};
// Current input: function outputs for both calls
let current_input = vec![
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Weather result: {\"temp\": 68}".to_string(),
}],
}),
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Time result: {\"time\": \"14:30\"}".to_string(),
}],
}),
];
let merged = storage.merge(&prev_state, current_input);
// Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs
assert_eq!(merged.len(), 5);
// Verify first item is original user message
let InputItem::Message(first) = &merged[0];
assert!(matches!(first.role, MessageRole::User));
// Verify last two are function outputs
let InputItem::Message(second_last) = &merged[3];
assert!(matches!(second_last.role, MessageRole::User));
match &second_last.content[0] {
InputContent::InputText { text } => assert!(text.contains("Weather result")),
_ => panic!("Expected InputText"),
}
let InputItem::Message(last) = &merged[4];
assert!(matches!(last.role, MessageRole::User));
match &last.content[0] {
InputContent::InputText { text } => assert!(text.contains("Time result")),
_ => panic!("Expected InputText"),
}
}
#[tokio::test]
async fn test_merge_preserves_conversation_context_for_multi_turn() {
// Simulate a multi-turn conversation with tool calls
let storage = MemoryConversationalStorage::new();
// Previous state: full conversation history up to this point
let prev_state = OpenAIConversationState {
response_id: "resp_tool_003".to_string(),
input_items: vec![
// Turn 1: User asks about weather
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "What's the weather?".to_string(),
}],
}),
// Turn 1: Assistant calls get_weather
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "Called function: get_weather".to_string(),
}],
}),
// Turn 2: User provides function output
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Weather: sunny, 72°F".to_string(),
}],
}),
// Turn 2: Assistant responds with text
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "It's sunny and 72°F in San Francisco today!".to_string(),
}],
}),
],
created_at: 1234567890,
model: "claude-3".to_string(),
provider: "anthropic".to_string(),
};
// Turn 3: User asks follow-up question
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Should I bring an umbrella?".to_string(),
}],
})];
let merged = storage.merge(&prev_state, current_input);
// Should have all 5 messages in order
assert_eq!(merged.len(), 5);
// Verify the entire conversation flow is preserved
let InputItem::Message(first) = &merged[0];
match &first.content[0] {
InputContent::InputText { text } => assert!(text.contains("What's the weather")),
_ => panic!("Expected InputText"),
}
let InputItem::Message(last) = &merged[4];
match &last.content[0] {
InputContent::InputText { text } => assert!(text.contains("umbrella")),
_ => panic!("Expected InputText"),
}
}
}

View file

@ -0,0 +1,157 @@
use async_trait::async_trait;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageRole, InputParam};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
use std::sync::Arc;
use tracing::{debug, info};
pub mod memory;
pub mod response_state_processor;
pub mod supabase;
/// Represents the conversational state for a v1/responses request
/// Contains the complete input/output history that can be restored
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConversationState {
/// The response ID this state is associated with
pub response_id: String,
/// The complete input history (original input + accumulated outputs)
/// This is what gets prepended to new requests via previous_response_id
pub input_items: Vec<InputItem>,
/// Timestamp when this state was created
pub created_at: i64,
/// Model used for this response
pub model: String,
/// Provider that generated this response (e.g., "anthropic", "openai")
pub provider: String,
}
/// Error types for state storage operations
#[derive(Debug)]
pub enum StateStorageError {
/// State not found for given response_id
NotFound(String),
/// Storage backend error (network, database, etc.)
StorageError(String),
/// Serialization/deserialization error
SerializationError(String),
}
impl fmt::Display for StateStorageError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id),
StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg),
StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
}
}
}
impl Error for StateStorageError {}
/// Trait for conversation state storage backends
#[async_trait]
pub trait StateStorage: Send + Sync {
/// Store conversation state for a response
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError>;
/// Retrieve conversation state by response_id
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError>;
/// Check if state exists for a response_id
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError>;
/// Delete state for a response_id (optional, for cleanup)
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>;
fn merge(
&self,
prev_state: &OpenAIConversationState,
current_input: Vec<InputItem>,
) -> Vec<InputItem> {
// Default implementation: prepend previous input, append current
let prev_count = prev_state.input_items.len();
let current_count = current_input.len();
let mut combined_input = prev_state.input_items.clone();
combined_input.extend(current_input);
debug!(
"PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}",
prev_state.response_id,
prev_count,
current_count,
combined_input.len(),
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
);
combined_input
}
}
/// Storage backend type enum
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageBackend {
Memory,
Supabase,
}
impl StorageBackend {
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"memory" => Some(StorageBackend::Memory),
"supabase" => Some(StorageBackend::Supabase),
_ => None,
}
}
}
// === Utility functions for state management ===
/// Extract input items from InputParam, converting text to structured format
pub fn extract_input_items(input: &InputParam) -> Vec<InputItem> {
match input {
InputParam::Text(text) => {
vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: text.clone(),
}],
})]
}
InputParam::Items(items) => items.clone(),
}
}
/// Retrieve previous conversation state and combine with current input
/// Returns combined input if previous state found, or original input if not found/error
pub async fn retrieve_and_combine_input(
storage: Arc<dyn StateStorage>,
previous_response_id: &str,
current_input: Vec<InputItem>,
) -> Result<Vec<InputItem>, StateStorageError> {
info!(
"Retrieving conversation state for previous_response_id: {}",
previous_response_id
);
// First get the previous state
let prev_state = storage.get(previous_response_id).await?;
let combined_input = storage.merge(&prev_state, current_input);
debug!(
"Retrieved and merged conversation state: {} total input items",
combined_input.len()
);
Ok(combined_input)
}

View file

@ -0,0 +1,307 @@
use bytes::Bytes;
use flate2::read::GzDecoder;
use hermesllm::apis::openai_responses::{
InputItem, OutputItem, ResponsesAPIStreamEvent,
};
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
use std::io::Read;
use std::sync::Arc;
use tracing::{info, debug, warn};
use crate::handlers::utils::StreamProcessor;
use crate::state::{OpenAIConversationState, StateStorage};
/// Processor that wraps another processor and handles v1/responses state management
/// Captures response_id and output from streaming responses, stores state after completion
pub struct ResponsesStateProcessor<P: StreamProcessor> {
/// The underlying processor (e.g., ObservableStreamProcessor for metrics)
inner: P,
/// State storage backend
storage: Arc<dyn StateStorage>,
/// Original input items from the request
original_input: Vec<InputItem>,
/// Model name
model: String,
/// Provider name
provider: String,
/// Whether this is a streaming request
is_streaming: bool,
/// Whether upstream is OpenAI (skip storage if true)
is_openai_upstream: bool,
/// Content-Encoding header value (e.g., "gzip", "br", None)
content_encoding: Option<String>,
/// Request ID for logging
request_id: String,
/// Buffer for accumulating chunks (needed for non-streaming compressed responses)
chunk_buffer: Vec<u8>,
/// Captured response_id from response.completed event
response_id: Option<String>,
/// Captured output items from response.completed event
output_items: Option<Vec<OutputItem>>,
}
impl<P: StreamProcessor> ResponsesStateProcessor<P> {
pub fn new(
inner: P,
storage: Arc<dyn StateStorage>,
original_input: Vec<InputItem>,
model: String,
provider: String,
is_streaming: bool,
is_openai_upstream: bool,
content_encoding: Option<String>,
request_id: String,
) -> Self {
Self {
inner,
storage,
original_input,
model,
provider,
is_streaming,
is_openai_upstream,
content_encoding,
request_id,
chunk_buffer: Vec::new(),
response_id: None,
output_items: None,
}
}
/// Decompress accumulated buffer based on Content-Encoding header
fn decompress_buffer(&self) -> Vec<u8> {
if self.chunk_buffer.is_empty() {
return Vec::new();
}
match self.content_encoding.as_deref() {
Some("gzip") => {
let mut decoder = GzDecoder::new(self.chunk_buffer.as_slice());
let mut decompressed = Vec::new();
match decoder.read_to_end(&mut decompressed) {
Ok(_) => {
debug!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Successfully decompressed {} bytes to {} bytes",
self.request_id,
self.chunk_buffer.len(),
decompressed.len()
);
decompressed
}
Err(e) => {
warn!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Failed to decompress gzip buffer: {}",
self.request_id,
e
);
self.chunk_buffer.clone()
}
}
}
Some(encoding) => {
warn!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Unsupported Content-Encoding: {}. Only gzip is currently supported.",
self.request_id,
encoding
);
self.chunk_buffer.clone()
}
None => self.chunk_buffer.clone(),
}
}
/// Parse response to extract response_id and output
/// For streaming: parse SSE events looking for response.completed (per chunk)
/// For non-streaming: buffer all chunks, then decompress and parse on completion
fn try_parse_response_chunk(&mut self, chunk: &[u8]) {
if self.is_streaming {
// Streaming: Try to parse SSE events from this chunk
// Note: For compressed streaming, we'd need to buffer and decompress first
// but most streaming responses aren't compressed since SSE needs to be readable
let sse_iter = match SseStreamIter::try_from(chunk) {
Ok(iter) => iter,
Err(_) => return, // Not valid SSE format, skip
};
// Process each SSE event in the chunk, looking for data lines with response.completed
for event in sse_iter {
// Only process data lines (skip event-only lines)
if let Some(data_str) = &event.data {
// Try to parse as ResponsesAPIStreamEvent
if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) {
// Check if this is a ResponseCompleted event
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event {
debug!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}, output_json={}",
self.request_id,
response.id,
response.output.len(),
serde_json::to_string(&response.output).unwrap_or_else(|_| "serialization_error".to_string())
);
self.response_id = Some(response.id.clone());
self.output_items = Some(response.output.clone());
return; // Found what we need, exit early
}
}
}
}
} else {
// Non-streaming: Buffer chunks, will decompress and parse on completion
self.chunk_buffer.extend_from_slice(chunk);
}
}
/// Parse buffered non-streaming response (called on completion)
fn try_parse_buffered_response(&mut self) {
if self.is_streaming || self.chunk_buffer.is_empty() {
return;
}
// Decompress if needed
let decompressed = self.decompress_buffer();
// Parse complete JSON response
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) {
Ok(response) => {
info!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",
self.request_id,
response.id,
response.output.len()
);
self.response_id = Some(response.id.clone());
self.output_items = Some(response.output.clone());
}
Err(e) => {
// Log parse error with chunk preview for debugging
let chunk_preview = String::from_utf8_lossy(&decompressed);
let preview_len = chunk_preview.len().min(200);
warn!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Failed to parse non-streaming ResponsesAPIResponse: {}. Decompressed preview (first {} bytes): {}",
self.request_id,
e,
preview_len,
&chunk_preview[..preview_len]
);
}
}
}
}
impl<P: StreamProcessor> StreamProcessor for ResponsesStateProcessor<P> {
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
// Buffer/parse chunk for response extraction
self.try_parse_response_chunk(&chunk);
// Forward to inner processor
self.inner.process_chunk(chunk)
}
fn on_first_bytes(&mut self) {
self.inner.on_first_bytes();
}
fn on_complete(&mut self) {
// For non-streaming, decompress and parse buffered response
self.try_parse_buffered_response();
// First, let the inner processor complete
self.inner.on_complete();
// Skip storage for OpenAI upstream
if self.is_openai_upstream {
debug!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Skipping state storage for OpenAI upstream provider",
self.request_id
);
return;
}
// Store state if we captured response_id and output
if let (Some(response_id), Some(output_items)) = (&self.response_id, &self.output_items) {
debug!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Output items before conversion: {}",
self.request_id,
serde_json::to_string(&output_items).unwrap_or_else(|_| "serialization_error".to_string())
);
// Convert output items to input items for next request
let output_as_inputs = outputs_to_inputs(output_items);
debug!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Converting outputs to inputs: output_items_count={}, converted_input_items_count={}",
self.request_id, output_items.len(), output_as_inputs.len()
);
// Combine original input + output as new input history
let mut combined_input = self.original_input.clone();
combined_input.extend(output_as_inputs);
debug!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Storing state: original_input_count={}, combined_input_count={}, combined_json={}",
self.request_id,
self.original_input.len(),
combined_input.len(),
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
);
let state = OpenAIConversationState {
response_id: response_id.clone(),
input_items: combined_input,
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64,
model: self.model.clone(),
provider: self.provider.clone(),
};
// Store asynchronously (fire and forget with logging)
let storage = self.storage.clone();
let response_id_clone = response_id.clone();
let request_id = self.request_id.clone();
tokio::spawn(async move {
match storage.put(state).await {
Ok(()) => {
debug!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Successfully stored conversation state for response_id: {}",
request_id,
response_id_clone
);
}
Err(e) => {
warn!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | Failed to store conversation state for response_id {}: {}",
request_id,
response_id_clone,
e
);
}
}
});
} else {
warn!(
"[PLANO_REQ_ID:{}] | BRIGHTSTAFF | STATE_PROCESSOR | No response_id captured from upstream response - cannot store conversation state. response_id present: {}, output present: {}",
self.request_id,
self.response_id.is_some(),
self.output_items.is_some()
);
}
}
fn on_error(&mut self, error: &str) {
self.inner.on_error(error);
}
}

View file

@ -0,0 +1,223 @@
use super::{OpenAIConversationState, StateStorage, StateStorageError};
use async_trait::async_trait;
use tracing::{debug, warn};
/// Supabase/PostgreSQL storage backend for conversation state
/// This is a placeholder implementation that can be extended with actual PostgreSQL logic
#[derive(Clone)]
pub struct SupabaseConversationalStorage {
// Connection pool or client would go here
// e.g., sqlx::PgPool or tokio_postgres::Client
_connection_string: String,
}
impl SupabaseConversationalStorage {
pub fn new(connection_string: String) -> Self {
Self {
_connection_string: connection_string,
}
}
}
#[async_trait]
impl StateStorage for SupabaseConversationalStorage {
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
warn!(
"Supabase storage not yet implemented - would store response_id: {}",
state.response_id
);
// TODO: Implement PostgreSQL storage
// SQL: INSERT INTO conversation_states (response_id, input_items, created_at, model, provider)
// VALUES ($1, $2, $3, $4, $5)
// ON CONFLICT (response_id) DO UPDATE SET ...
Err(StateStorageError::StorageError(
"Supabase storage not yet implemented".to_string(),
))
}
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
warn!(
"Supabase storage not yet implemented - would retrieve response_id: {}",
response_id
);
// TODO: Implement PostgreSQL retrieval
// SQL: SELECT * FROM conversation_states WHERE response_id = $1
Err(StateStorageError::StorageError(
"Supabase storage not yet implemented".to_string(),
))
}
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
debug!("Checking existence for response_id: {}", response_id);
// TODO: Implement PostgreSQL existence check
// SQL: SELECT EXISTS(SELECT 1 FROM conversation_states WHERE response_id = $1)
Err(StateStorageError::StorageError(
"Supabase storage not yet implemented".to_string(),
))
}
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
debug!("Deleting response_id: {}", response_id);
// TODO: Implement PostgreSQL deletion
// SQL: DELETE FROM conversation_states WHERE response_id = $1
Err(StateStorageError::StorageError(
"Supabase storage not yet implemented".to_string(),
))
}
}
/*
Suggested PostgreSQL schema:
CREATE TABLE conversation_states (
response_id TEXT PRIMARY KEY,
input_items JSONB NOT NULL,
created_at BIGINT NOT NULL,
model TEXT NOT NULL,
provider TEXT NOT NULL,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX idx_conversation_states_created_at ON conversation_states(created_at);
CREATE INDEX idx_conversation_states_provider ON conversation_states(provider);
*/
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent};
fn create_test_state(response_id: &str) -> OpenAIConversationState {
OpenAIConversationState {
response_id: response_id.to_string(),
input_items: vec![
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Test message".to_string(),
}],
}),
],
created_at: 1234567890,
model: "gpt-4".to_string(),
provider: "openai".to_string(),
}
}
// These tests validate the current "not implemented" behavior
// Once the Supabase implementation is complete with actual PostgreSQL integration,
// these should be replaced with comprehensive tests similar to memory.rs
#[tokio::test]
async fn test_supabase_put_returns_not_implemented() {
let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string());
let state = create_test_state("resp_001");
let result = storage.put(state).await;
assert!(result.is_err());
match result.unwrap_err() {
StateStorageError::StorageError(msg) => {
assert!(msg.contains("not yet implemented"));
}
_ => panic!("Expected StorageError"),
}
}
#[tokio::test]
async fn test_supabase_get_returns_not_implemented() {
let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string());
let result = storage.get("resp_002").await;
assert!(result.is_err());
match result.unwrap_err() {
StateStorageError::StorageError(msg) => {
assert!(msg.contains("not yet implemented"));
}
_ => panic!("Expected StorageError"),
}
}
#[tokio::test]
async fn test_supabase_exists_returns_not_implemented() {
let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string());
let result = storage.exists("resp_003").await;
assert!(result.is_err());
match result.unwrap_err() {
StateStorageError::StorageError(msg) => {
assert!(msg.contains("not yet implemented"));
}
_ => panic!("Expected StorageError"),
}
}
#[tokio::test]
async fn test_supabase_delete_returns_not_implemented() {
let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string());
let result = storage.delete("resp_004").await;
assert!(result.is_err());
match result.unwrap_err() {
StateStorageError::StorageError(msg) => {
assert!(msg.contains("not yet implemented"));
}
_ => panic!("Expected StorageError"),
}
}
#[tokio::test]
async fn test_supabase_merge_works() {
// merge() is implemented in the trait default, so it should work even without DB
let storage = SupabaseConversationalStorage::new("mock_connection_string".to_string());
let prev_state = create_test_state("resp_005");
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "New message".to_string(),
}],
})];
let merged = storage.merge(&prev_state, current_input);
// Should have 2 messages (1 from prev + 1 current)
assert_eq!(merged.len(), 2);
}
/* TODO: Add comprehensive tests when SupabaseConversationalStorage is implemented
*
* Once the actual PostgreSQL integration is complete, add tests similar to those
* in memory.rs, including:
*
* - test_supabase_put_and_get_success: Store and retrieve state
* - test_supabase_put_overwrites_existing: Verify upsert behavior
* - test_supabase_get_not_found: Check NotFound error handling
* - test_supabase_exists_returns_false: Test non-existent ID
* - test_supabase_exists_returns_true_after_put: Verify existence after insert
* - test_supabase_delete_success: Delete and verify removal
* - test_supabase_delete_not_found: Delete non-existent ID
* - test_supabase_merge_various_scenarios: Test merge with different input combinations
* - test_supabase_concurrent_access: Test with multiple concurrent operations
* - test_supabase_serialization: Verify JSON serialization of input_items
* - test_supabase_connection_failure: Handle connection errors
* - test_supabase_invalid_data: Handle malformed JSON in database
*
* Test setup would require:
* - Test database setup/teardown (perhaps using testcontainers-rs or docker)
* - Connection pool initialization
* - Table creation before tests
* - Data cleanup between tests
*/
}

View file

@ -59,6 +59,11 @@ pub struct ResponsesAPIStreamBuffer {
model: Option<String>,
created_at: Option<i64>,
/// Full response metadata from upstream (tools, temperature, etc.)
/// This is extracted from the first upstream event and used to build
/// complete response.created and response.in_progress events
upstream_response_metadata: Option<ResponsesAPIResponse>,
/// Lifecycle state flags
created_emitted: bool,
in_progress_emitted: bool,
@ -88,6 +93,7 @@ impl ResponsesAPIStreamBuffer {
response_id: None,
model: None,
created_at: None,
upstream_response_metadata: None,
created_emitted: false,
in_progress_emitted: false,
output_items_added: HashMap::new(),
@ -171,6 +177,15 @@ impl ResponsesAPIStreamBuffer {
/// Build the base response object with current state
fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse {
// If we have upstream metadata, use it as a base and update status/output
if let Some(upstream) = &self.upstream_response_metadata {
let mut response = upstream.clone();
response.status = status;
// Don't update output here - will be set in finalize()
return response;
}
// Fallback: build a minimal response from local state
ResponsesAPIResponse {
id: self.response_id.clone().unwrap_or_default(),
object: "response".to_string(),
@ -293,24 +308,40 @@ impl ResponsesAPIStreamBuffer {
// Build final response
let mut output_items = Vec::new();
// Add tool calls to output
for (item_id, arguments) in &self.function_arguments {
let output_index = self.output_items_added.iter()
.find(|(_, id)| *id == item_id)
.map(|(idx, _)| *idx)
.unwrap_or(0);
// Build complete output array by iterating through all output indices in order
let max_output_index = self.output_items_added.keys().max().copied().unwrap_or(-1);
let (call_id, name) = self.tool_call_metadata.get(&output_index)
.cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
for output_index in 0..=max_output_index {
if let Some(item_id) = self.output_items_added.get(&output_index) {
// Check if this is a function call
if let Some(arguments) = self.function_arguments.get(item_id) {
let (call_id, name) = self.tool_call_metadata.get(&output_index)
.cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
output_items.push(OutputItem::FunctionCall {
id: item_id.clone(),
status: OutputItemStatus::Completed,
call_id,
name: Some(name),
arguments: Some(arguments.clone()),
});
output_items.push(OutputItem::FunctionCall {
id: item_id.clone(),
status: OutputItemStatus::Completed,
call_id,
name: Some(name),
arguments: Some(arguments.clone()),
});
}
// Check if this is a text message
else if let Some(text) = self.text_content.get(item_id) {
use crate::apis::openai_responses::OutputContent;
output_items.push(OutputItem::Message {
id: item_id.clone(),
status: OutputItemStatus::Completed,
role: "assistant".to_string(),
content: vec![OutputContent::OutputText {
text: text.clone(),
annotations: vec![],
logprobs: None,
}],
});
}
}
}
let mut final_response = self.build_response(ResponseStatus::Completed);
@ -365,6 +396,24 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
let mut events = Vec::new();
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present
match stream_event {
ResponsesAPIStreamEvent::ResponseCreated { response, .. } |
ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
if self.upstream_response_metadata.is_none() {
// Store the full upstream response as our metadata template
self.upstream_response_metadata = Some(response.clone());
// Also extract basic fields
self.response_id = Some(response.id.clone());
self.model = Some(response.model.clone());
self.created_at = Some(response.created_at);
}
// Don't emit these - we'll generate our own lifecycle events
return;
}
_ => {}
}
// Emit lifecycle events if not yet emitted
if !self.created_emitted {
// Initialize metadata from first event if needed

View file

@ -193,6 +193,40 @@ impl SupportedAPIsFromClient {
}
}
impl SupportedUpstreamAPIs {
/// Create a SupportedUpstreamApi from an endpoint path
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
// Check if this is the Responses API endpoint
if openai_api == OpenAIApi::Responses {
return Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(openai_api));
}
// Otherwise it's ChatCompletions
return Some(SupportedUpstreamAPIs::OpenAIChatCompletions(openai_api));
}
if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) {
return Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(anthropic_api));
}
if let Some(bedrock_api) = AmazonBedrockApi::from_endpoint(endpoint) {
match bedrock_api {
AmazonBedrockApi::Converse => {
return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api))
}
AmazonBedrockApi::ConverseStream => {
return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(bedrock_api))
}
}
}
None
}
}
/// Get all supported endpoint paths
pub fn supported_endpoints() -> Vec<&'static str> {
let mut endpoints = Vec::new();

View file

@ -1,3 +1,4 @@
//! Response transformation modules
pub mod output_to_input;
pub mod to_anthropic;
pub mod to_openai;

View file

@ -0,0 +1,166 @@
//! Conversions from response outputs to request inputs for conversation continuation
//!
//! This module provides utilities for converting OutputItem types from API responses
//! into InputItem types that can be used in subsequent requests. This is primarily used
//! for maintaining conversation history in the v1/responses API.
use crate::apis::openai_responses::{
InputContent, InputItem, InputMessage, MessageRole, OutputContent, OutputItem,
};
/// Converts an OutputItem from a response into an InputItem for the next request
/// This is used to build conversation history from previous responses
pub fn output_item_to_input_item(output: &OutputItem) -> Option<InputItem> {
match output {
// Convert output messages to input messages
OutputItem::Message {
role, content, ..
} => {
let input_content: Vec<InputContent> = content
.iter()
.filter_map(|c| match c {
OutputContent::OutputText { text, .. } => Some(InputContent::InputText {
text: text.clone(),
}),
OutputContent::OutputAudio {
data, ..
} => Some(InputContent::InputAudio {
data: data.clone(),
format: None, // Format not preserved in output
}),
OutputContent::Refusal { .. } => None, // Skip refusals
})
.collect();
if input_content.is_empty() {
return None;
}
// Map role string to MessageRole enum
let message_role = match role.as_str() {
"user" => MessageRole::User,
"assistant" => MessageRole::Assistant,
"system" => MessageRole::System,
"developer" => MessageRole::Developer,
_ => MessageRole::Assistant, // Default to assistant
};
Some(InputItem::Message(InputMessage {
role: message_role,
content: input_content,
}))
}
// For function calls, we'll create an assistant message with the tool call info
// This matches how conversation history is typically built
OutputItem::FunctionCall {
name, arguments, ..
} => {
let tool_call_text = if let (Some(n), Some(args)) = (name, arguments) {
format!("Called function: {} with arguments: {}", n, args)
} else {
"Called a function".to_string()
};
Some(InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: tool_call_text,
}],
}))
}
// Skip other output types (tool outputs, etc.) as they don't convert to input
_ => None,
}
}
/// Converts a Vec of OutputItems into InputItems for conversation continuation
pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec<InputItem> {
outputs
.iter()
.filter_map(output_item_to_input_item)
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::openai_responses::{OutputItemStatus};
#[test]
fn test_output_message_to_input() {
let output = OutputItem::Message {
id: "msg_123".to_string(),
status: OutputItemStatus::Completed,
role: "assistant".to_string(),
content: vec![OutputContent::OutputText {
text: "Hello!".to_string(),
annotations: vec![],
logprobs: None,
}],
};
let input = output_item_to_input_item(&output).unwrap();
match input {
InputItem::Message(msg) => {
assert!(matches!(msg.role, MessageRole::Assistant));
assert_eq!(msg.content.len(), 1);
match &msg.content[0] {
InputContent::InputText { text } => assert_eq!(text, "Hello!"),
_ => panic!("Expected InputText"),
}
}
}
}
#[test]
fn test_function_call_to_input() {
let output = OutputItem::FunctionCall {
id: "fc_123".to_string(),
status: OutputItemStatus::Completed,
call_id: "call_123".to_string(),
name: Some("get_weather".to_string()),
arguments: Some(r#"{"location":"SF"}"#.to_string()),
};
let input = output_item_to_input_item(&output).unwrap();
match input {
InputItem::Message(msg) => {
assert!(matches!(msg.role, MessageRole::Assistant));
match &msg.content[0] {
InputContent::InputText { text } => {
assert!(text.contains("get_weather"));
}
_ => panic!("Expected InputText"),
}
}
}
}
#[test]
fn test_outputs_to_inputs() {
let outputs = vec![
OutputItem::Message {
id: "msg_1".to_string(),
status: OutputItemStatus::Completed,
role: "assistant".to_string(),
content: vec![OutputContent::OutputText {
text: "Hello".to_string(),
annotations: vec![],
logprobs: None,
}],
},
OutputItem::FunctionCall {
id: "fc_1".to_string(),
status: OutputItemStatus::Completed,
call_id: "call_1".to_string(),
name: Some("test".to_string()),
arguments: Some("{}".to_string()),
},
];
let inputs = outputs_to_inputs(&outputs);
assert_eq!(inputs.len(), 2);
}
}

View file

@ -80,8 +80,15 @@ impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
// Only add the message item if there's actual content (text, audio, or refusal)
// Don't add empty message items when there are only tool calls
if !content.is_empty() {
// Avoid double-prefixing: if ID already starts with "msg_", use as-is
let message_id = if resp.id.starts_with("msg_") {
resp.id.clone()
} else {
format!("msg_{}", resp.id)
};
items.push(OutputItem::Message {
id: format!("msg_{}", resp.id),
id: message_id,
status: OutputItemStatus::Completed,
role: match choice.message.role {
Role::User => "user".to_string(),
@ -151,7 +158,12 @@ impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
};
Ok(ResponsesAPIResponse {
id: resp.id,
// Generate proper resp_ prefixed ID if not already present
id: if resp.id.starts_with("resp_") {
resp.id
} else {
format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", ""))
},
object: "response".to_string(),
created_at: resp.created as i64,
status,
@ -942,7 +954,7 @@ mod tests {
use crate::apis::openai_responses::{OutputContent, OutputItem, ResponsesAPIResponse};
let chat_response = ChatCompletionsResponse {
id: "chatcmpl-123".to_string(),
id: "resp_6de5512800cf4375a329a473a4f02879".to_string(),
object: Some("chat.completion".to_string()),
created: 1677652288,
model: "gpt-4".to_string(),
@ -974,7 +986,9 @@ mod tests {
let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();
assert_eq!(responses_api.id, "chatcmpl-123");
// Response ID should be generated with resp_ prefix
assert!(responses_api.id.starts_with("resp_"), "Response ID should start with 'resp_'");
assert_eq!(responses_api.id.len(), 37, "Response ID should be resp_ + 32 char UUID");
assert_eq!(responses_api.object, "response");
assert_eq!(responses_api.model, "gpt-4");

View file

@ -58,11 +58,11 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
None,
)),
MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
convert_content_block_start(content_block)
MessagesStreamEvent::ContentBlockStart { content_block, index } => {
convert_content_block_start(content_block, index)
}
MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
MessagesStreamEvent::ContentBlockDelta { delta, index } => convert_content_delta(delta, index),
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
@ -272,6 +272,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
/// Convert content block start to OpenAI chunk
fn convert_content_block_start(
content_block: MessagesContentBlock,
index: u32,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
match content_block {
MessagesContentBlock::Text { .. } => {
@ -291,7 +292,7 @@ fn convert_content_block_start(
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
index,
id: Some(id),
call_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
@ -313,6 +314,7 @@ fn convert_content_block_start(
/// Convert content delta to OpenAI chunk
fn convert_content_delta(
delta: MessagesContentDelta,
index: u32,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
match delta {
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
@ -350,7 +352,7 @@ fn convert_content_delta(
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
index,
id: None,
call_type: None,
function: Some(FunctionCallDelta {

View file

@ -628,3 +628,204 @@ def test_openai_responses_api_streaming_with_tools_upstream_anthropic():
assert (
full_text or tool_calls
), "Expected streamed text or tool call argument deltas from Responses tools stream"
def test_conversation_state_management_two_turn():
"""
Test conversation state management across two turns:
1. Send initial message to non-OpenAI model via v1/responses
2. Capture response_id from first response
3. Send second message with previous_response_id
4. Verify model receives both messages in correct order
"""
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
logger.info("\n" + "=" * 80)
logger.info("TEST: Conversation State Management - Two Turn Flow")
logger.info("=" * 80)
# Turn 1: Send initial message to Anthropic (non-OpenAI model)
logger.info("\n[TURN 1] Sending initial message...")
resp1 = client.responses.create(
model="claude-sonnet-4-20250514",
input="My name is Alice and I like pizza.",
)
# Extract response_id from first response
response_id_1 = resp1.id
logger.info(f"[TURN 1] Received response_id: {response_id_1}")
logger.info(f"[TURN 1] Model response: {resp1.output_text}")
assert response_id_1 is not None, "First response should have an id"
assert len(resp1.output_text) > 0, "First response should have content"
# Turn 2: Send follow-up message with previous_response_id
# Ask the model to list all messages to verify state was combined
logger.info(
f"\n[TURN 2] Sending follow-up with previous_response_id={response_id_1}"
)
resp2 = client.responses.create(
model="claude-sonnet-4-20250514",
input="Please list all the messages you have received in our conversation, numbering each one.",
previous_response_id=response_id_1,
)
response_id_2 = resp2.id
logger.info(f"[TURN 2] Received response_id: {response_id_2}")
logger.info(f"[TURN 2] Model response: {resp2.output_text}")
assert response_id_2 is not None, "Second response should have an id"
assert response_id_2 != response_id_1, "Second response should have different id"
# Verify the model received the conversation history
# The response should reference both the initial message and the follow-up
response_lower = resp2.output_text.lower()
# Check if the model acknowledges receiving multiple messages
# Different models might format this differently, so we check for various indicators
has_conversation_context = (
"alice" in response_lower
or "pizza" in response_lower # References the name from turn 1
or "two" in response_lower # References the preference from turn 1
or "2" in response_lower # Mentions number of messages
or "first" in response_lower # Numeric indicator
or "second" # References first message
in response_lower # References second message
)
logger.info(
f"\n[VALIDATION] Conversation context preserved: {has_conversation_context}"
)
logger.info(
f"[VALIDATION] Response contains conversation markers: {has_conversation_context}"
)
print(f"\n{'='*80}")
print("Conversation State Test Results:")
print(f"Turn 1 Response ID: {response_id_1}")
print(f"Turn 2 Response ID: {response_id_2}")
print(f"Turn 1 Output: {resp1.output_text[:100]}...")
print(f"Turn 2 Output: {resp2.output_text}")
print(f"Conversation Context Preserved: {has_conversation_context}")
print(f"{'='*80}\n")
assert has_conversation_context, (
f"Model should have received conversation history. "
f"Response: {resp2.output_text}"
)
def test_conversation_state_management_two_turn_streaming():
"""
Test conversation state management across two turns with streaming:
1. Send initial streaming message to non-OpenAI model via v1/responses
2. Capture response_id from first response
3. Send second streaming message with previous_response_id
4. Verify model receives both messages in correct order
"""
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
logger.info("\n" + "=" * 80)
logger.info("TEST: Conversation State Management - Two Turn Streaming Flow")
logger.info("=" * 80)
# Turn 1: Send initial streaming message to Anthropic (non-OpenAI model)
logger.info("\n[TURN 1] Sending initial streaming message...")
stream1 = client.responses.create(
model="claude-sonnet-4-20250514",
input="My name is Alice and I like pizza.",
stream=True,
)
# Collect streamed content and capture response_id
text_chunks_1 = []
response_id_1 = None
for event in stream1:
if getattr(event, "type", None) == "response.output_text.delta" and getattr(
event, "delta", None
):
text_chunks_1.append(event.delta)
# Capture response_id from response.completed event
if getattr(event, "type", None) == "response.completed" and getattr(
event, "response", None
):
response_id_1 = event.response.id
output_1 = "".join(text_chunks_1)
logger.info(f"[TURN 1] Received response_id: {response_id_1}")
logger.info(f"[TURN 1] Model response: {output_1}")
assert response_id_1 is not None, "First response should have an id"
assert len(output_1) > 0, "First response should have content"
# Turn 2: Send follow-up streaming message with previous_response_id
logger.info(
f"\n[TURN 2] Sending follow-up streaming request with previous_response_id={response_id_1}"
)
stream2 = client.responses.create(
model="claude-sonnet-4-20250514",
input="Please list all the messages you have received in our conversation, numbering each one.",
previous_response_id=response_id_1,
stream=True,
)
# Collect streamed content from second response
text_chunks_2 = []
response_id_2 = None
for event in stream2:
if getattr(event, "type", None) == "response.output_text.delta" and getattr(
event, "delta", None
):
text_chunks_2.append(event.delta)
# Capture response_id from response.completed event
if getattr(event, "type", None) == "response.completed" and getattr(
event, "response", None
):
response_id_2 = event.response.id
output_2 = "".join(text_chunks_2)
logger.info(f"[TURN 2] Received response_id: {response_id_2}")
logger.info(f"[TURN 2] Model response: {output_2}")
assert response_id_2 is not None, "Second response should have an id"
assert response_id_2 != response_id_1, "Second response should have different id"
# Verify the model received the conversation history
response_lower = output_2.lower()
# Check if the model acknowledges receiving multiple messages
has_conversation_context = (
"alice" in response_lower
or "pizza" in response_lower # References the name from turn 1
or "two" in response_lower # References the preference from turn 1
or "2" in response_lower # Mentions number of messages
or "first" in response_lower # Numeric indicator
or "second" # References first message
in response_lower # References second message
)
logger.info(
f"\n[VALIDATION] Conversation context preserved: {has_conversation_context}"
)
logger.info(
f"[VALIDATION] Response contains conversation markers: {has_conversation_context}"
)
print(f"\n{'='*80}")
print("Streaming Conversation State Test Results:")
print(f"Turn 1 Response ID: {response_id_1}")
print(f"Turn 2 Response ID: {response_id_2}")
print(f"Turn 1 Output: {output_1[:100]}...")
print(f"Turn 2 Output: {output_2}")
print(f"Conversation Context Preserved: {has_conversation_context}")
print(f"{'='*80}\n")
assert has_conversation_context, (
f"Model should have received conversation history. " f"Response: {output_2}"
)