plano/crates/brightstaff/src/state/memory.rs

584 lines
21 KiB
Rust

use super::{OpenAIConversationState, StateStorage, StateStorageError};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// In-memory storage backend for conversation state
/// Uses a HashMap wrapped in Arc<RwLock<>> for thread-safe access
#[derive(Clone)]
pub struct MemoryConversationalStorage {
storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
}
impl MemoryConversationalStorage {
pub fn new() -> Self {
Self {
storage: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl Default for MemoryConversationalStorage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl StateStorage for MemoryConversationalStorage {
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
let response_id = state.response_id.clone();
let mut storage = self.storage.write().await;
debug!(
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Storing conversation state: model={}, provider={}, input_items={}",
response_id, state.model, state.provider, state.input_items.len()
);
storage.insert(response_id, state);
Ok(())
}
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
let storage = self.storage.read().await;
match storage.get(response_id) {
Some(state) => {
debug!(
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Retrieved conversation state: input_items={}",
response_id, state.input_items.len()
);
Ok(state.clone())
}
None => {
warn!(
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Conversation state not found",
response_id
);
Err(StateStorageError::NotFound(response_id.to_string()))
}
}
}
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
let storage = self.storage.read().await;
Ok(storage.contains_key(response_id))
}
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
let mut storage = self.storage.write().await;
if storage.remove(response_id).is_some() {
debug!(
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Deleted conversation state",
response_id
);
Ok(())
} else {
Err(StateStorageError::NotFound(response_id.to_string()))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent};
fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
let mut input_items = Vec::new();
for i in 0..num_messages {
input_items.push(InputItem::Message(InputMessage {
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
content: vec![InputContent::InputText {
text: format!("Message {}", i),
}],
}));
}
OpenAIConversationState {
response_id: response_id.to_string(),
input_items,
created_at: 1234567890,
model: "claude-3".to_string(),
provider: "anthropic".to_string(),
}
}
#[tokio::test]
async fn test_put_and_get_success() {
let storage = MemoryConversationalStorage::new();
let state: OpenAIConversationState = create_test_state("resp_001", 3);
// Store
storage.put(state.clone()).await.unwrap();
// Retrieve
let retrieved = storage.get("resp_001").await.unwrap();
assert_eq!(retrieved.response_id, state.response_id);
assert_eq!(retrieved.model, state.model);
assert_eq!(retrieved.provider, state.provider);
assert_eq!(retrieved.input_items.len(), 3);
assert_eq!(retrieved.created_at, state.created_at);
}
#[tokio::test]
async fn test_put_overwrites_existing() {
let storage = MemoryConversationalStorage::new();
// First state
let state1 = create_test_state("resp_002", 2);
storage.put(state1).await.unwrap();
// Overwrite with new state
let state2 = OpenAIConversationState {
response_id: "resp_002".to_string(),
input_items: vec![],
created_at: 9999999999,
model: "gpt-4".to_string(),
provider: "openai".to_string(),
};
storage.put(state2.clone()).await.unwrap();
// Should retrieve the new state
let retrieved = storage.get("resp_002").await.unwrap();
assert_eq!(retrieved.model, "gpt-4");
assert_eq!(retrieved.provider, "openai");
assert_eq!(retrieved.input_items.len(), 0);
assert_eq!(retrieved.created_at, 9999999999);
}
#[tokio::test]
async fn test_get_not_found() {
let storage = MemoryConversationalStorage::new();
let result = storage.get("nonexistent").await;
assert!(result.is_err());
match result.unwrap_err() {
StateStorageError::NotFound(id) => {
assert_eq!(id, "nonexistent");
}
_ => panic!("Expected NotFound error"),
}
}
#[tokio::test]
async fn test_exists_returns_false_for_nonexistent() {
let storage = MemoryConversationalStorage::new();
assert!(!storage.exists("resp_003").await.unwrap());
}
#[tokio::test]
async fn test_exists_returns_true_after_put() {
let storage = MemoryConversationalStorage::new();
let state = create_test_state("resp_004", 1);
assert!(!storage.exists("resp_004").await.unwrap());
storage.put(state).await.unwrap();
assert!(storage.exists("resp_004").await.unwrap());
}
#[tokio::test]
async fn test_delete_success() {
let storage = MemoryConversationalStorage::new();
let state = create_test_state("resp_005", 2);
storage.put(state).await.unwrap();
assert!(storage.exists("resp_005").await.unwrap());
// Delete
storage.delete("resp_005").await.unwrap();
// Should no longer exist
assert!(!storage.exists("resp_005").await.unwrap());
assert!(storage.get("resp_005").await.is_err());
}
#[tokio::test]
async fn test_delete_not_found() {
let storage = MemoryConversationalStorage::new();
let result = storage.delete("nonexistent").await;
assert!(result.is_err());
match result.unwrap_err() {
StateStorageError::NotFound(id) => {
assert_eq!(id, "nonexistent");
}
_ => panic!("Expected NotFound error"),
}
}
#[tokio::test]
async fn test_merge_combines_inputs() {
let storage = MemoryConversationalStorage::new();
// Create a previous state with 2 messages
let prev_state = create_test_state("resp_006", 2);
// Create current input with 1 message
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "New message".to_string(),
}],
})];
// Merge
let merged = storage.merge(&prev_state, current_input);
// Should have 3 messages total (2 from prev + 1 current)
assert_eq!(merged.len(), 3);
}
#[tokio::test]
async fn test_merge_preserves_order() {
let storage = MemoryConversationalStorage::new();
// Previous state has messages 0 and 1
let prev_state = create_test_state("resp_007", 2);
// Current input has message 2
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Message 2".to_string(),
}],
})];
let merged = storage.merge(&prev_state, current_input);
// Verify order: prev messages first, then current
let InputItem::Message(msg) = &merged[0];
match &msg.content[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 0"),
_ => panic!("Expected InputText"),
}
let InputItem::Message(msg) = &merged[2];
match &msg.content[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 2"),
_ => panic!("Expected InputText"),
}
}
#[tokio::test]
async fn test_merge_with_empty_current_input() {
let storage = MemoryConversationalStorage::new();
let prev_state = create_test_state("resp_008", 3);
let merged = storage.merge(&prev_state, vec![]);
// Should just have the previous state's items
assert_eq!(merged.len(), 3);
}
#[tokio::test]
async fn test_merge_with_empty_previous_state() {
let storage = MemoryConversationalStorage::new();
let prev_state = OpenAIConversationState {
response_id: "resp_009".to_string(),
input_items: vec![],
created_at: 1234567890,
model: "gpt-4".to_string(),
provider: "openai".to_string(),
};
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Only message".to_string(),
}],
})];
let merged = storage.merge(&prev_state, current_input);
// Should just have the current input
assert_eq!(merged.len(), 1);
}
#[tokio::test]
async fn test_concurrent_access() {
let storage = MemoryConversationalStorage::new();
// Spawn multiple tasks that write concurrently
let mut handles = vec![];
for i in 0..10 {
let storage_clone = storage.clone();
let handle = tokio::spawn(async move {
let state = create_test_state(&format!("resp_{}", i), i % 3);
storage_clone.put(state).await.unwrap();
});
handles.push(handle);
}
// Wait for all tasks
for handle in handles {
handle.await.unwrap();
}
// Verify all states were stored
for i in 0..10 {
assert!(storage.exists(&format!("resp_{}", i)).await.unwrap());
}
}
#[tokio::test]
async fn test_multiple_operations_on_same_id() {
let storage = MemoryConversationalStorage::new();
let state = create_test_state("resp_010", 1);
// Put
storage.put(state.clone()).await.unwrap();
// Get
let retrieved = storage.get("resp_010").await.unwrap();
assert_eq!(retrieved.response_id, "resp_010");
// Exists
assert!(storage.exists("resp_010").await.unwrap());
// Put again (overwrite)
let new_state = create_test_state("resp_010", 5);
storage.put(new_state).await.unwrap();
// Get updated
let updated = storage.get("resp_010").await.unwrap();
assert_eq!(updated.input_items.len(), 5);
// Delete
storage.delete("resp_010").await.unwrap();
// Should not exist
assert!(!storage.exists("resp_010").await.unwrap());
}
#[tokio::test]
async fn test_merge_with_tool_call_flow() {
// This test simulates a realistic tool call conversation flow:
// 1. User sends message: "What's the weather?"
// 2. Model responds with function call (converted to assistant message)
// 3. User sends function call output in next request with previous_response_id
// The merge should combine: user message + assistant function call + function output
let storage = MemoryConversationalStorage::new();
// Step 1: Previous state contains the initial exchange
// - User message: "What's the weather in SF?"
// - Assistant message (converted from FunctionCall): "Called function: get_weather..."
let prev_state = OpenAIConversationState {
response_id: "resp_tool_001".to_string(),
input_items: vec![
// Original user message
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "What's the weather in San Francisco?".to_string(),
}],
}),
// Assistant's function call (converted from OutputItem::FunctionCall)
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(),
}],
}),
],
created_at: 1234567890,
model: "claude-3".to_string(),
provider: "anthropic".to_string(),
};
// Step 2: Current request includes function call output
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
}],
})];
// Step 3: Merge should combine all conversation history
let merged = storage.merge(&prev_state, current_input);
// Should have 3 items: user question + assistant function call + function output
assert_eq!(merged.len(), 3);
// Verify the order and content
let InputItem::Message(msg1) = &merged[0];
assert!(matches!(msg1.role, MessageRole::User));
match &msg1.content[0] {
InputContent::InputText { text } => {
assert!(text.contains("weather in San Francisco"));
}
_ => panic!("Expected InputText"),
}
let InputItem::Message(msg2) = &merged[1];
assert!(matches!(msg2.role, MessageRole::Assistant));
match &msg2.content[0] {
InputContent::InputText { text } => {
assert!(text.contains("get_weather"));
}
_ => panic!("Expected InputText"),
}
let InputItem::Message(msg3) = &merged[2];
assert!(matches!(msg3.role, MessageRole::User));
match &msg3.content[0] {
InputContent::InputText { text } => {
assert!(text.contains("Function result"));
assert!(text.contains("temperature"));
}
_ => panic!("Expected InputText"),
}
}
#[tokio::test]
async fn test_merge_with_multiple_tool_calls() {
// Test a more complex scenario with multiple tool calls
let storage = MemoryConversationalStorage::new();
// Previous state has: user message + 2 function calls from assistant
let prev_state = OpenAIConversationState {
response_id: "resp_tool_002".to_string(),
input_items: vec![
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "What's the weather and time in SF?".to_string(),
}],
}),
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(),
}],
}),
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(),
}],
}),
],
created_at: 1234567890,
model: "gpt-4".to_string(),
provider: "openai".to_string(),
};
// Current input: function outputs for both calls
let current_input = vec![
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Weather result: {\"temp\": 68}".to_string(),
}],
}),
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Time result: {\"time\": \"14:30\"}".to_string(),
}],
}),
];
let merged = storage.merge(&prev_state, current_input);
// Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs
assert_eq!(merged.len(), 5);
// Verify first item is original user message
let InputItem::Message(first) = &merged[0];
assert!(matches!(first.role, MessageRole::User));
// Verify last two are function outputs
let InputItem::Message(second_last) = &merged[3];
assert!(matches!(second_last.role, MessageRole::User));
match &second_last.content[0] {
InputContent::InputText { text } => assert!(text.contains("Weather result")),
_ => panic!("Expected InputText"),
}
let InputItem::Message(last) = &merged[4];
assert!(matches!(last.role, MessageRole::User));
match &last.content[0] {
InputContent::InputText { text } => assert!(text.contains("Time result")),
_ => panic!("Expected InputText"),
}
}
#[tokio::test]
async fn test_merge_preserves_conversation_context_for_multi_turn() {
// Simulate a multi-turn conversation with tool calls
let storage = MemoryConversationalStorage::new();
// Previous state: full conversation history up to this point
let prev_state = OpenAIConversationState {
response_id: "resp_tool_003".to_string(),
input_items: vec![
// Turn 1: User asks about weather
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "What's the weather?".to_string(),
}],
}),
// Turn 1: Assistant calls get_weather
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "Called function: get_weather".to_string(),
}],
}),
// Turn 2: User provides function output
InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Weather: sunny, 72°F".to_string(),
}],
}),
// Turn 2: Assistant responds with text
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: vec![InputContent::InputText {
text: "It's sunny and 72°F in San Francisco today!".to_string(),
}],
}),
],
created_at: 1234567890,
model: "claude-3".to_string(),
provider: "anthropic".to_string(),
};
// Turn 3: User asks follow-up question
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: vec![InputContent::InputText {
text: "Should I bring an umbrella?".to_string(),
}],
})];
let merged = storage.merge(&prev_state, current_input);
// Should have all 5 messages in order
assert_eq!(merged.len(), 5);
// Verify the entire conversation flow is preserved
let InputItem::Message(first) = &merged[0];
match &first.content[0] {
InputContent::InputText { text } => assert!(text.contains("What's the weather")),
_ => panic!("Expected InputText"),
}
let InputItem::Message(last) = &merged[4];
match &last.content[0] {
InputContent::InputText { text } => assert!(text.contains("umbrella")),
_ => panic!("Expected InputText"),
}
}
}