mirror of
https://github.com/katanemo/plano.git
synced 2026-05-09 07:42:43 +02:00
enable state management for v1/responses (#631)
* first commit with tests to enable state management via memory * fixed logs to follow the conversational flow a bit better * added support for supabase * added the state_storage_v1_responses flag, and use that to store state appropriately * cleaned up logs and fixed issue with connectivity for llm gateway in weather forecast demo * fixed mixed inputs from openai v1/responses api (#632) * fixed mixed inputs from openai v1/responses api * removing tracing from model-alias-routing * handling additional input types from openai's responses API --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local> * resolving PR comments --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
33e90dd338
commit
d5a273f740
26 changed files with 2687 additions and 76 deletions
611
crates/brightstaff/src/state/memory.rs
Normal file
611
crates/brightstaff/src/state/memory.rs
Normal file
|
|
@ -0,0 +1,611 @@
|
|||
use super::{OpenAIConversationState, StateStorage, StateStorageError};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// In-memory storage backend for conversation state.
///
/// Backed by a `HashMap` keyed by `response_id`, wrapped in
/// `Arc<RwLock<..>>` so clones of this struct share one map and access is
/// safe across concurrent tasks. State is lost on process restart; use the
/// PostgreSQL backend for durability.
#[derive(Clone)]
pub struct MemoryConversationalStorage {
    // response_id -> full conversation state; RwLock allows many concurrent
    // readers (get/exists) while writers (put/delete) take exclusive access.
    storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
}
|
||||
|
||||
impl MemoryConversationalStorage {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
storage: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryConversationalStorage {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StateStorage for MemoryConversationalStorage {
|
||||
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
|
||||
let response_id = state.response_id.clone();
|
||||
let mut storage = self.storage.write().await;
|
||||
|
||||
debug!(
|
||||
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Storing conversation state: model={}, provider={}, input_items={}",
|
||||
response_id, state.model, state.provider, state.input_items.len()
|
||||
);
|
||||
|
||||
storage.insert(response_id, state);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
|
||||
let storage = self.storage.read().await;
|
||||
|
||||
match storage.get(response_id) {
|
||||
Some(state) => {
|
||||
debug!(
|
||||
"[PLANO | MEMORY_STORAGE | RESP_ID:{} | Retrieved conversation state: input_items={}",
|
||||
response_id, state.input_items.len()
|
||||
);
|
||||
Ok(state.clone())
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
"[PLANO_RESP_ID:{} | MEMORY_STORAGE | Conversation state not found",
|
||||
response_id
|
||||
);
|
||||
Err(StateStorageError::NotFound(response_id.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
|
||||
let storage = self.storage.read().await;
|
||||
Ok(storage.contains_key(response_id))
|
||||
}
|
||||
|
||||
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
|
||||
let mut storage = self.storage.write().await;
|
||||
|
||||
if storage.remove(response_id).is_some() {
|
||||
debug!(
|
||||
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Deleted conversation state",
|
||||
response_id
|
||||
);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(StateStorageError::NotFound(response_id.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unit tests for the in-memory conversation state backend.
//
// Covers: CRUD round-trips, overwrite semantics, NotFound error paths,
// the trait's default `merge` behavior (ordering, empty sides), concurrent
// writers, and realistic multi-turn / tool-call conversation flows.
#[cfg(test)]
mod tests {
    use super::*;
    use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent};

    // Builds a state with `num_messages` alternating user/assistant messages
    // ("Message 0", "Message 1", ...) so ordering can be asserted later.
    fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
        let mut input_items = Vec::new();
        for i in 0..num_messages {
            input_items.push(InputItem::Message(InputMessage {
                role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
                content: MessageContent::Items(vec![InputContent::InputText {
                    text: format!("Message {}", i),
                }]),
            }));
        }

        OpenAIConversationState {
            response_id: response_id.to_string(),
            input_items,
            created_at: 1234567890,
            model: "claude-3".to_string(),
            provider: "anthropic".to_string(),
        }
    }

    #[tokio::test]
    async fn test_put_and_get_success() {
        let storage = MemoryConversationalStorage::new();
        let state: OpenAIConversationState = create_test_state("resp_001", 3);

        // Store
        storage.put(state.clone()).await.unwrap();

        // Retrieve
        let retrieved = storage.get("resp_001").await.unwrap();
        assert_eq!(retrieved.response_id, state.response_id);
        assert_eq!(retrieved.model, state.model);
        assert_eq!(retrieved.provider, state.provider);
        assert_eq!(retrieved.input_items.len(), 3);
        assert_eq!(retrieved.created_at, state.created_at);
    }

    #[tokio::test]
    async fn test_put_overwrites_existing() {
        let storage = MemoryConversationalStorage::new();

        // First state
        let state1 = create_test_state("resp_002", 2);
        storage.put(state1).await.unwrap();

        // Overwrite with new state
        let state2 = OpenAIConversationState {
            response_id: "resp_002".to_string(),
            input_items: vec![],
            created_at: 9999999999,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        };
        storage.put(state2.clone()).await.unwrap();

        // Should retrieve the new state
        let retrieved = storage.get("resp_002").await.unwrap();
        assert_eq!(retrieved.model, "gpt-4");
        assert_eq!(retrieved.provider, "openai");
        assert_eq!(retrieved.input_items.len(), 0);
        assert_eq!(retrieved.created_at, 9999999999);
    }

    #[tokio::test]
    async fn test_get_not_found() {
        let storage = MemoryConversationalStorage::new();

        let result = storage.get("nonexistent").await;
        assert!(result.is_err());

        // The error must carry the response_id that was looked up.
        match result.unwrap_err() {
            StateStorageError::NotFound(id) => {
                assert_eq!(id, "nonexistent");
            }
            _ => panic!("Expected NotFound error"),
        }
    }

    #[tokio::test]
    async fn test_exists_returns_false_for_nonexistent() {
        let storage = MemoryConversationalStorage::new();
        assert!(!storage.exists("resp_003").await.unwrap());
    }

    #[tokio::test]
    async fn test_exists_returns_true_after_put() {
        let storage = MemoryConversationalStorage::new();
        let state = create_test_state("resp_004", 1);

        assert!(!storage.exists("resp_004").await.unwrap());
        storage.put(state).await.unwrap();
        assert!(storage.exists("resp_004").await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_success() {
        let storage = MemoryConversationalStorage::new();
        let state = create_test_state("resp_005", 2);

        storage.put(state).await.unwrap();
        assert!(storage.exists("resp_005").await.unwrap());

        // Delete
        storage.delete("resp_005").await.unwrap();

        // Should no longer exist
        assert!(!storage.exists("resp_005").await.unwrap());
        assert!(storage.get("resp_005").await.is_err());
    }

    #[tokio::test]
    async fn test_delete_not_found() {
        let storage = MemoryConversationalStorage::new();

        let result = storage.delete("nonexistent").await;
        assert!(result.is_err());

        match result.unwrap_err() {
            StateStorageError::NotFound(id) => {
                assert_eq!(id, "nonexistent");
            }
            _ => panic!("Expected NotFound error"),
        }
    }

    #[tokio::test]
    async fn test_merge_combines_inputs() {
        let storage = MemoryConversationalStorage::new();

        // Create a previous state with 2 messages
        let prev_state = create_test_state("resp_006", 2);

        // Create current input with 1 message
        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "New message".to_string(),
            }]),
        })];

        // Merge
        let merged = storage.merge(&prev_state, current_input);

        // Should have 3 messages total (2 from prev + 1 current)
        assert_eq!(merged.len(), 3);
    }

    #[tokio::test]
    async fn test_merge_preserves_order() {
        let storage = MemoryConversationalStorage::new();

        // Previous state has messages 0 and 1
        let prev_state = create_test_state("resp_007", 2);

        // Current input has message 2
        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Message 2".to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);

        // Verify order: prev messages first, then current
        let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") };
        match &msg.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert_eq!(text, "Message 0"),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") };
        match &msg.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert_eq!(text, "Message 2"),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }
    }

    #[tokio::test]
    async fn test_merge_with_empty_current_input() {
        let storage = MemoryConversationalStorage::new();
        let prev_state = create_test_state("resp_008", 3);

        let merged = storage.merge(&prev_state, vec![]);

        // Should just have the previous state's items
        assert_eq!(merged.len(), 3);
    }

    #[tokio::test]
    async fn test_merge_with_empty_previous_state() {
        let storage = MemoryConversationalStorage::new();

        let prev_state = OpenAIConversationState {
            response_id: "resp_009".to_string(),
            input_items: vec![],
            created_at: 1234567890,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        };

        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Only message".to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);

        // Should just have the current input
        assert_eq!(merged.len(), 1);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let storage = MemoryConversationalStorage::new();

        // Spawn multiple tasks that write concurrently
        let mut handles = vec![];

        for i in 0..10 {
            let storage_clone = storage.clone();
            let handle = tokio::spawn(async move {
                let state = create_test_state(&format!("resp_{}", i), i % 3);
                storage_clone.put(state).await.unwrap();
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap();
        }

        // Verify all states were stored
        for i in 0..10 {
            assert!(storage.exists(&format!("resp_{}", i)).await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_multiple_operations_on_same_id() {
        let storage = MemoryConversationalStorage::new();
        let state = create_test_state("resp_010", 1);

        // Put
        storage.put(state.clone()).await.unwrap();

        // Get
        let retrieved = storage.get("resp_010").await.unwrap();
        assert_eq!(retrieved.response_id, "resp_010");

        // Exists
        assert!(storage.exists("resp_010").await.unwrap());

        // Put again (overwrite)
        let new_state = create_test_state("resp_010", 5);
        storage.put(new_state).await.unwrap();

        // Get updated
        let updated = storage.get("resp_010").await.unwrap();
        assert_eq!(updated.input_items.len(), 5);

        // Delete
        storage.delete("resp_010").await.unwrap();

        // Should not exist
        assert!(!storage.exists("resp_010").await.unwrap());
    }

    #[tokio::test]
    async fn test_merge_with_tool_call_flow() {
        // This test simulates a realistic tool call conversation flow:
        // 1. User sends message: "What's the weather?"
        // 2. Model responds with function call (converted to assistant message)
        // 3. User sends function call output in next request with previous_response_id
        // The merge should combine: user message + assistant function call + function output

        let storage = MemoryConversationalStorage::new();

        // Step 1: Previous state contains the initial exchange
        // - User message: "What's the weather in SF?"
        // - Assistant message (converted from FunctionCall): "Called function: get_weather..."
        let prev_state = OpenAIConversationState {
            response_id: "resp_tool_001".to_string(),
            input_items: vec![
                // Original user message
                InputItem::Message(InputMessage {
                    role: MessageRole::User,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "What's the weather in San Francisco?".to_string(),
                    }]),
                }),
                // Assistant's function call (converted from OutputItem::FunctionCall)
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(),
                    }]),
                }),
            ],
            created_at: 1234567890,
            model: "claude-3".to_string(),
            provider: "anthropic".to_string(),
        };

        // Step 2: Current request includes function call output
        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
            }]),
        })];

        // Step 3: Merge should combine all conversation history
        let merged = storage.merge(&prev_state, current_input);

        // Should have 3 items: user question + assistant function call + function output
        assert_eq!(merged.len(), 3);

        // Verify the order and content
        let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") };
        assert!(matches!(msg1.role, MessageRole::User));
        match &msg1.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => {
                    assert!(text.contains("weather in San Francisco"));
                }
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") };
        assert!(matches!(msg2.role, MessageRole::Assistant));
        match &msg2.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => {
                    assert!(text.contains("get_weather"));
                }
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") };
        assert!(matches!(msg3.role, MessageRole::User));
        match &msg3.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => {
                    assert!(text.contains("Function result"));
                    assert!(text.contains("temperature"));
                }
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }
    }

    #[tokio::test]
    async fn test_merge_with_multiple_tool_calls() {
        // Test a more complex scenario with multiple tool calls
        let storage = MemoryConversationalStorage::new();

        // Previous state has: user message + 2 function calls from assistant
        let prev_state = OpenAIConversationState {
            response_id: "resp_tool_002".to_string(),
            input_items: vec![
                InputItem::Message(InputMessage {
                    role: MessageRole::User,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "What's the weather and time in SF?".to_string(),
                    }]),
                }),
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(),
                    }]),
                }),
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(),
                    }]),
                }),
            ],
            created_at: 1234567890,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        };

        // Current input: function outputs for both calls
        let current_input = vec![
            InputItem::Message(InputMessage {
                role: MessageRole::User,
                content: MessageContent::Items(vec![InputContent::InputText {
                    text: "Weather result: {\"temp\": 68}".to_string(),
                }]),
            }),
            InputItem::Message(InputMessage {
                role: MessageRole::User,
                content: MessageContent::Items(vec![InputContent::InputText {
                    text: "Time result: {\"time\": \"14:30\"}".to_string(),
                }]),
            }),
        ];

        let merged = storage.merge(&prev_state, current_input);

        // Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs
        assert_eq!(merged.len(), 5);

        // Verify first item is original user message
        let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
        assert!(matches!(first.role, MessageRole::User));

        // Verify last two are function outputs
        let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") };
        assert!(matches!(second_last.role, MessageRole::User));
        match &second_last.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert!(text.contains("Weather result")),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
        assert!(matches!(last.role, MessageRole::User));
        match &last.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert!(text.contains("Time result")),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }
    }

    #[tokio::test]
    async fn test_merge_preserves_conversation_context_for_multi_turn() {
        // Simulate a multi-turn conversation with tool calls
        let storage = MemoryConversationalStorage::new();

        // Previous state: full conversation history up to this point
        let prev_state = OpenAIConversationState {
            response_id: "resp_tool_003".to_string(),
            input_items: vec![
                // Turn 1: User asks about weather
                InputItem::Message(InputMessage {
                    role: MessageRole::User,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "What's the weather?".to_string(),
                    }]),
                }),
                // Turn 1: Assistant calls get_weather
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Called function: get_weather".to_string(),
                    }]),
                }),
                // Turn 2: User provides function output
                InputItem::Message(InputMessage {
                    role: MessageRole::User,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Weather: sunny, 72°F".to_string(),
                    }]),
                }),
                // Turn 2: Assistant responds with text
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "It's sunny and 72°F in San Francisco today!".to_string(),
                    }]),
                }),
            ],
            created_at: 1234567890,
            model: "claude-3".to_string(),
            provider: "anthropic".to_string(),
        };

        // Turn 3: User asks follow-up question
        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Should I bring an umbrella?".to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);

        // Should have all 5 messages in order
        assert_eq!(merged.len(), 5);

        // Verify the entire conversation flow is preserved
        let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
        match &first.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert!(text.contains("What's the weather")),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
        match &last.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert!(text.contains("umbrella")),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }
    }
}
|
||||
147
crates/brightstaff/src/state/mod.rs
Normal file
147
crates/brightstaff/src/state/mod.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
use async_trait::async_trait;
|
||||
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug};
|
||||
|
||||
pub mod memory;
|
||||
pub mod response_state_processor;
|
||||
pub mod postgresql;
|
||||
|
||||
/// Represents the conversational state for a v1/responses request.
///
/// Serialized with serde so it can be persisted as JSON by storage
/// backends (e.g. the PostgreSQL JSONB column).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConversationState {
    /// The response ID this state is associated with (storage key).
    pub response_id: String,

    /// The complete input history (original input + accumulated outputs).
    /// This is what gets prepended to new requests via previous_response_id.
    pub input_items: Vec<InputItem>,

    /// Timestamp when this state was created.
    /// NOTE(review): unit (seconds vs millis) is not established here — confirm at call sites.
    pub created_at: i64,

    /// Model used for this response.
    pub model: String,

    /// Provider that generated this response (e.g., "anthropic", "openai").
    pub provider: String,
}
|
||||
|
||||
/// Errors produced by conversation-state storage backends.
#[derive(Debug)]
pub enum StateStorageError {
    /// No state is stored under the given response_id.
    NotFound(String),

    /// Backend failure (connectivity, database, etc.).
    StorageError(String),

    /// Failure while (de)serializing stored state.
    SerializationError(String),
}

impl fmt::Display for StateStorageError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render each variant with its payload; text is part of the
        // public error contract and must stay stable.
        match self {
            Self::NotFound(id) => {
                write!(f, "Conversation state not found for response_id: {}", id)
            }
            Self::StorageError(msg) => write!(f, "Storage error: {}", msg),
            Self::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
        }
    }
}

impl Error for StateStorageError {}
|
||||
|
||||
/// Trait for conversation state storage backends
|
||||
#[async_trait]
|
||||
pub trait StateStorage: Send + Sync {
|
||||
/// Store conversation state for a response
|
||||
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError>;
|
||||
|
||||
/// Retrieve conversation state by response_id
|
||||
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError>;
|
||||
|
||||
/// Check if state exists for a response_id
|
||||
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError>;
|
||||
|
||||
/// Delete state for a response_id (optional, for cleanup)
|
||||
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>;
|
||||
|
||||
fn merge(
|
||||
&self,
|
||||
prev_state: &OpenAIConversationState,
|
||||
current_input: Vec<InputItem>,
|
||||
) -> Vec<InputItem> {
|
||||
// Default implementation: prepend previous input, append current
|
||||
let prev_count = prev_state.input_items.len();
|
||||
let current_count = current_input.len();
|
||||
|
||||
let mut combined_input = prev_state.input_items.clone();
|
||||
combined_input.extend(current_input);
|
||||
|
||||
debug!(
|
||||
"PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}",
|
||||
prev_state.response_id,
|
||||
prev_count,
|
||||
current_count,
|
||||
combined_input.len(),
|
||||
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
|
||||
);
|
||||
|
||||
combined_input
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Identifies which state-storage backend is configured.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageBackend {
    Memory,
    Supabase,
}

impl StorageBackend {
    /// Parses a backend name case-insensitively; `None` for unknown names.
    pub fn from_str(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        if normalized == "memory" {
            Some(StorageBackend::Memory)
        } else if normalized == "supabase" {
            Some(StorageBackend::Supabase)
        } else {
            None
        }
    }
}
|
||||
|
||||
// === Utility functions for state management ===
|
||||
|
||||
/// Extract input items from InputParam, converting text to structured format
|
||||
pub fn extract_input_items(input: &InputParam) -> Vec<InputItem> {
|
||||
match input {
|
||||
InputParam::Text(text) => {
|
||||
vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: text.clone(),
|
||||
}]),
|
||||
})]
|
||||
}
|
||||
InputParam::Items(items) => items.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve previous conversation state and combine with current input
|
||||
/// Returns combined input if previous state found, or original input if not found/error
|
||||
pub async fn retrieve_and_combine_input(
|
||||
storage: Arc<dyn StateStorage>,
|
||||
previous_response_id: &str,
|
||||
current_input: Vec<InputItem>,
|
||||
) -> Result<Vec<InputItem>, StateStorageError> {
|
||||
|
||||
// First get the previous state
|
||||
let prev_state = storage.get(previous_response_id).await?;
|
||||
let combined_input = storage.merge(&prev_state, current_input);
|
||||
Ok(combined_input)
|
||||
}
|
||||
432
crates/brightstaff/src/state/postgresql.rs
Normal file
432
crates/brightstaff/src/state/postgresql.rs
Normal file
|
|
@ -0,0 +1,432 @@
|
|||
use super::{OpenAIConversationState, StateStorage, StateStorageError};
|
||||
use async_trait::async_trait;
|
||||
use serde_json;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::OnceCell;
|
||||
use tokio_postgres::{Client, NoTls};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Supabase/PostgreSQL storage backend for conversation state.
///
/// Holds a shared `tokio_postgres::Client`; clones of this struct share the
/// same connection. `table_verified` memoizes the one-time check that the
/// `conversation_states` table exists, so it runs at most once per instance.
#[derive(Clone)]
pub struct PostgreSQLConversationStorage {
    client: Arc<Client>,
    // OnceCell so the table-existence query is performed only on first use.
    table_verified: Arc<OnceCell<()>>,
}
|
||||
|
||||
impl PostgreSQLConversationStorage {
    /// Creates a new Supabase storage instance with the given connection string.
    ///
    /// Connects without TLS (`NoTls`) and drives the connection on a spawned
    /// background task, as required by tokio_postgres' client/connection split.
    ///
    /// # Errors
    /// Returns `StateStorageError::StorageError` if the initial connection fails.
    pub async fn new(connection_string: String) -> Result<Self, StateStorageError> {
        let (client, connection) = tokio_postgres::connect(&connection_string, NoTls)
            .await
            .map_err(|e| {
                StateStorageError::StorageError(format!("Failed to connect to database: {}", e))
            })?;

        // Spawn the connection to run in the background; if it dies, queries
        // on `client` will start failing and we only log the cause here.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                warn!("Database connection error: {}", e);
            }
        });

        Ok(Self {
            client: Arc::new(client),
            table_verified: Arc::new(OnceCell::new()),
        })
    }

    /// Ensures the conversation_states table exists (checks once, caches result).
    ///
    /// Uses `OnceCell::get_or_try_init` so a failed check is retried on the
    /// next call, while a successful check is cached for the lifetime of
    /// this instance.
    async fn ensure_ready(&self) -> Result<(), StateStorageError> {
        self.table_verified
            .get_or_try_init(|| async {
                // pg_tables lookup; NOTE(review): this does not filter by
                // schema, so a same-named table in any schema passes — confirm
                // that is acceptable for the deployment.
                let row = self
                    .client
                    .query_one(
                        "SELECT EXISTS (
                            SELECT FROM pg_tables
                            WHERE tablename = 'conversation_states'
                        )",
                        &[],
                    )
                    .await
                    .map_err(|e| {
                        StateStorageError::StorageError(format!(
                            "Failed to verify table existence: {}",
                            e
                        ))
                    })?;

                let exists: bool = row.get(0);

                if !exists {
                    return Err(StateStorageError::StorageError(
                        "Table 'conversation_states' does not exist. \
                         Please run the setup SQL from docs/db_setup/conversation_states.sql"
                            .to_string(),
                    ));
                }

                info!("Conversation state storage table verified");
                Ok(())
            })
            .await?;

        Ok(())
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl StateStorage for PostgreSQLConversationStorage {
|
||||
    /// Stores conversation state, upserting on `response_id`.
    ///
    /// `input_items` is serialized to JSON for the JSONB column; on conflict
    /// the existing row's items/model/provider are replaced and `updated_at`
    /// refreshed (the original `created_at` column value is overwritten only
    /// on insert, per the upsert's SET list).
    ///
    /// # Errors
    /// `StorageError` on serialization failure or any database error.
    async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
        self.ensure_ready().await?;

        // Serialize input_items to JSONB
        let input_items_json = serde_json::to_value(&state.input_items).map_err(|e| {
            StateStorageError::StorageError(format!("Failed to serialize input_items: {}", e))
        })?;

        // Upsert the conversation state
        self.client
            .execute(
                r#"
                INSERT INTO conversation_states
                    (response_id, input_items, created_at, model, provider, updated_at)
                VALUES ($1, $2, $3, $4, $5, NOW())
                ON CONFLICT (response_id)
                DO UPDATE SET
                    input_items = EXCLUDED.input_items,
                    model = EXCLUDED.model,
                    provider = EXCLUDED.provider,
                    updated_at = NOW()
                "#,
                &[
                    &state.response_id,
                    &input_items_json,
                    &state.created_at,
                    &state.model,
                    &state.provider,
                ],
            )
            .await
            .map_err(|e| {
                StateStorageError::StorageError(format!(
                    "Failed to store conversation state for {}: {}",
                    state.response_id, e
                ))
            })?;

        debug!("Stored conversation state for {}", state.response_id);
        Ok(())
    }
|
||||
|
||||
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
|
||||
self.ensure_ready().await?;
|
||||
|
||||
let row = self
|
||||
.client
|
||||
.query_opt(
|
||||
r#"
|
||||
SELECT response_id, input_items, created_at, model, provider
|
||||
FROM conversation_states
|
||||
WHERE response_id = $1
|
||||
"#,
|
||||
&[&response_id],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to fetch conversation state for {}: {}",
|
||||
response_id, e
|
||||
))
|
||||
})?;
|
||||
|
||||
match row {
|
||||
Some(row) => {
|
||||
let response_id: String = row.get("response_id");
|
||||
let input_items_json: serde_json::Value = row.get("input_items");
|
||||
let created_at: i64 = row.get("created_at");
|
||||
let model: String = row.get("model");
|
||||
let provider: String = row.get("provider");
|
||||
|
||||
// Deserialize input_items from JSONB
|
||||
let input_items =
|
||||
serde_json::from_value(input_items_json).map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to deserialize input_items: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(OpenAIConversationState {
|
||||
response_id,
|
||||
input_items,
|
||||
created_at,
|
||||
model,
|
||||
provider,
|
||||
})
|
||||
}
|
||||
None => Err(StateStorageError::NotFound(format!(
|
||||
"Conversation state not found for response_id: {}",
|
||||
response_id
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
|
||||
self.ensure_ready().await?;
|
||||
|
||||
let row = self
|
||||
.client
|
||||
.query_one(
|
||||
"SELECT EXISTS(SELECT 1 FROM conversation_states WHERE response_id = $1)",
|
||||
&[&response_id],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to check existence for {}: {}",
|
||||
response_id, e
|
||||
))
|
||||
})?;
|
||||
|
||||
let exists: bool = row.get(0);
|
||||
Ok(exists)
|
||||
}
|
||||
|
||||
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
|
||||
self.ensure_ready().await?;
|
||||
|
||||
let rows_affected = self
|
||||
.client
|
||||
.execute(
|
||||
"DELETE FROM conversation_states WHERE response_id = $1",
|
||||
&[&response_id],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to delete conversation state for {}: {}",
|
||||
response_id, e
|
||||
))
|
||||
})?;
|
||||
|
||||
if rows_affected == 0 {
|
||||
return Err(StateStorageError::NotFound(format!(
|
||||
"Conversation state not found for response_id: {}",
|
||||
response_id
|
||||
)));
|
||||
}
|
||||
|
||||
debug!("Deleted conversation state for {}", response_id);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
PostgreSQL schema is maintained in docs/db_setup/conversation_states.sql
|
||||
Run that SQL file against your database before using this storage backend.
|
||||
*/
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole};

    /// Builds a minimal single-user-message state for tests.
    fn create_test_state(response_id: &str) -> OpenAIConversationState {
        let user_message = InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Test message".to_string(),
            }]),
        });
        OpenAIConversationState {
            response_id: response_id.to_string(),
            input_items: vec![user_message],
            created_at: 1234567890,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        }
    }

    // Note: These tests require a running PostgreSQL database
    // Set TEST_DATABASE_URL environment variable to run integration tests
    // Example: TEST_DATABASE_URL=postgresql://user:pass@localhost/test_db

    /// Connects to the test database, or returns None (skipping the test)
    /// when TEST_DATABASE_URL is unset or the connection fails.
    async fn get_test_storage() -> Option<PostgreSQLConversationStorage> {
        let db_url = match std::env::var("TEST_DATABASE_URL") {
            Ok(url) => url,
            Err(_) => {
                eprintln!("TEST_DATABASE_URL not set, skipping Supabase integration tests");
                return None;
            }
        };
        match PostgreSQLConversationStorage::new(db_url).await {
            Ok(storage) => Some(storage),
            Err(e) => {
                eprintln!("Failed to create test storage: {}", e);
                None
            }
        }
    }

    #[tokio::test]
    async fn test_supabase_put_and_get_success() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        storage.put(create_test_state("test_resp_001")).await.unwrap();

        let retrieved = storage.get("test_resp_001").await.unwrap();
        assert_eq!(retrieved.response_id, "test_resp_001");
        assert_eq!(retrieved.input_items.len(), 1);
        assert_eq!(retrieved.model, "gpt-4");
        assert_eq!(retrieved.provider, "openai");

        // Cleanup
        let _ = storage.delete("test_resp_001").await;
    }

    #[tokio::test]
    async fn test_supabase_put_overwrites_existing() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        storage.put(create_test_state("test_resp_002")).await.unwrap();

        // Second put with the same id: different model and one extra message.
        let mut updated = create_test_state("test_resp_002");
        updated.model = "gpt-4-turbo".to_string();
        updated.input_items.push(InputItem::Message(InputMessage {
            role: MessageRole::Assistant,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Response".to_string(),
            }]),
        }));
        storage.put(updated).await.unwrap();

        let retrieved = storage.get("test_resp_002").await.unwrap();
        assert_eq!(retrieved.model, "gpt-4-turbo");
        assert_eq!(retrieved.input_items.len(), 2);

        // Cleanup
        let _ = storage.delete("test_resp_002").await;
    }

    #[tokio::test]
    async fn test_supabase_get_not_found() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        let result = storage.get("nonexistent_id").await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
    }

    #[tokio::test]
    async fn test_supabase_exists_returns_false() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        assert!(!storage.exists("nonexistent_id").await.unwrap());
    }

    #[tokio::test]
    async fn test_supabase_exists_returns_true_after_put() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        storage.put(create_test_state("test_resp_003")).await.unwrap();

        assert!(storage.exists("test_resp_003").await.unwrap());

        // Cleanup
        let _ = storage.delete("test_resp_003").await;
    }

    #[tokio::test]
    async fn test_supabase_delete_success() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        storage.put(create_test_state("test_resp_004")).await.unwrap();
        storage.delete("test_resp_004").await.unwrap();

        assert!(!storage.exists("test_resp_004").await.unwrap());
    }

    #[tokio::test]
    async fn test_supabase_delete_not_found() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        let result = storage.delete("nonexistent_id").await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
    }

    #[tokio::test]
    async fn test_supabase_merge_works() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        let prev_state = create_test_state("test_resp_005");
        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "New message".to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);

        // Should have 2 messages (1 from prev + 1 current)
        assert_eq!(merged.len(), 2);
    }

    #[tokio::test]
    async fn test_supabase_table_verification() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        // This should trigger table verification
        let result = storage.ensure_ready().await;
        assert!(result.is_ok(), "Table verification should succeed");

        // Second call should use cached result
        let result2 = storage.ensure_ready().await;
        assert!(result2.is_ok(), "Cached verification should succeed");
    }

    #[tokio::test]
    #[ignore] // Run manually with: cargo test test_verify_data_in_supabase -- --ignored
    async fn test_verify_data_in_supabase() {
        let storage = match get_test_storage().await {
            Some(s) => s,
            None => return,
        };

        // Create a test record that persists
        storage.put(create_test_state("manual_test_verification")).await.unwrap();

        println!("✅ Data written to Supabase!");
        println!("Check your Supabase dashboard:");
        println!("  SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';");
        println!("\nTo cleanup, run:");
        println!("  DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';");

        // DON'T cleanup - leave it for manual verification
    }
}
|
||||
302
crates/brightstaff/src/state/response_state_processor.rs
Normal file
302
crates/brightstaff/src/state/response_state_processor.rs
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
use bytes::Bytes;
|
||||
use flate2::read::GzDecoder;
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputItem, OutputItem, ResponsesAPIStreamEvent,
|
||||
};
|
||||
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
|
||||
use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, debug, warn};
|
||||
|
||||
use crate::handlers::utils::StreamProcessor;
|
||||
use crate::state::{OpenAIConversationState, StateStorage};
|
||||
|
||||
/// Processor that wraps another [`StreamProcessor`] and handles v1/responses
/// state management: it captures the `response_id` and output items from the
/// upstream response (per-chunk SSE parsing for streaming, full-body buffering
/// for non-streaming) and persists the combined conversation state after the
/// stream completes.
pub struct ResponsesStateProcessor<P: StreamProcessor> {
    /// The underlying processor (e.g., ObservableStreamProcessor for metrics);
    /// every callback is forwarded to it.
    inner: P,

    /// State storage backend the captured conversation state is written to.
    storage: Arc<dyn StateStorage>,

    /// Original input items from the request; combined with the converted
    /// output items to form the stored history.
    original_input: Vec<InputItem>,

    /// Model name recorded in the stored state.
    model: String,

    /// Provider name recorded in the stored state.
    provider: String,

    /// Whether this is a streaming request; selects the SSE parsing path
    /// instead of whole-body buffering.
    is_streaming: bool,

    /// Whether upstream is OpenAI (state storage is skipped if true).
    is_openai_upstream: bool,

    /// Content-Encoding header value (e.g., "gzip", "br", None); used to
    /// decompress the buffered body for non-streaming responses.
    content_encoding: Option<String>,

    /// Request ID for logging.
    request_id: String,

    /// Buffer for accumulating chunks (needed for non-streaming compressed responses).
    chunk_buffer: Vec<u8>,

    /// Captured response_id from response.completed event (streaming) or the
    /// parsed JSON body (non-streaming).
    response_id: Option<String>,

    /// Captured output items from response.completed event (or the parsed
    /// non-streaming body).
    output_items: Option<Vec<OutputItem>>,
}
|
||||
|
||||
impl<P: StreamProcessor> ResponsesStateProcessor<P> {
|
||||
pub fn new(
|
||||
inner: P,
|
||||
storage: Arc<dyn StateStorage>,
|
||||
original_input: Vec<InputItem>,
|
||||
model: String,
|
||||
provider: String,
|
||||
is_streaming: bool,
|
||||
is_openai_upstream: bool,
|
||||
content_encoding: Option<String>,
|
||||
request_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
storage,
|
||||
original_input,
|
||||
model,
|
||||
provider,
|
||||
is_streaming,
|
||||
is_openai_upstream,
|
||||
content_encoding,
|
||||
request_id,
|
||||
chunk_buffer: Vec::new(),
|
||||
response_id: None,
|
||||
output_items: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Decompress accumulated buffer based on Content-Encoding header
|
||||
fn decompress_buffer(&self) -> Vec<u8> {
|
||||
if self.chunk_buffer.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
match self.content_encoding.as_deref() {
|
||||
Some("gzip") => {
|
||||
let mut decoder = GzDecoder::new(self.chunk_buffer.as_slice());
|
||||
let mut decompressed = Vec::new();
|
||||
match decoder.read_to_end(&mut decompressed) {
|
||||
Ok(_) => {
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully decompressed {} bytes to {} bytes",
|
||||
self.request_id,
|
||||
self.chunk_buffer.len(),
|
||||
decompressed.len()
|
||||
);
|
||||
decompressed
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to decompress gzip buffer: {}",
|
||||
self.request_id,
|
||||
e
|
||||
);
|
||||
self.chunk_buffer.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(encoding) => {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Unsupported Content-Encoding: {}. Only gzip is currently supported.",
|
||||
self.request_id,
|
||||
encoding
|
||||
);
|
||||
self.chunk_buffer.clone()
|
||||
}
|
||||
None => self.chunk_buffer.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse response to extract response_id and output
|
||||
/// For streaming: parse SSE events looking for response.completed (per chunk)
|
||||
/// For non-streaming: buffer all chunks, then decompress and parse on completion
|
||||
fn try_parse_response_chunk(&mut self, chunk: &[u8]) {
|
||||
if self.is_streaming {
|
||||
// Streaming: Try to parse SSE events from this chunk
|
||||
// Note: For compressed streaming, we'd need to buffer and decompress first
|
||||
// but most streaming responses aren't compressed since SSE needs to be readable
|
||||
let sse_iter = match SseStreamIter::try_from(chunk) {
|
||||
Ok(iter) => iter,
|
||||
Err(_) => return, // Not valid SSE format, skip
|
||||
};
|
||||
|
||||
// Process each SSE event in the chunk, looking for data lines with response.completed
|
||||
for event in sse_iter {
|
||||
// Only process data lines (skip event-only lines)
|
||||
if let Some(data_str) = &event.data {
|
||||
// Try to parse as ResponsesAPIStreamEvent
|
||||
if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) {
|
||||
// Check if this is a ResponseCompleted event
|
||||
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
|
||||
self.request_id,
|
||||
response.id,
|
||||
response.output.len()
|
||||
);
|
||||
self.response_id = Some(response.id.clone());
|
||||
self.output_items = Some(response.output.clone());
|
||||
return; // Found what we need, exit early
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Non-streaming: Buffer chunks, will decompress and parse on completion
|
||||
self.chunk_buffer.extend_from_slice(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse buffered non-streaming response (called on completion)
|
||||
fn try_parse_buffered_response(&mut self) {
|
||||
if self.is_streaming || self.chunk_buffer.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Decompress if needed
|
||||
let decompressed = self.decompress_buffer();
|
||||
|
||||
// Parse complete JSON response
|
||||
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) {
|
||||
Ok(response) => {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",
|
||||
self.request_id,
|
||||
response.id,
|
||||
response.output.len()
|
||||
);
|
||||
self.response_id = Some(response.id.clone());
|
||||
self.output_items = Some(response.output.clone());
|
||||
}
|
||||
Err(e) => {
|
||||
// Log parse error with chunk preview for debugging
|
||||
let chunk_preview = String::from_utf8_lossy(&decompressed);
|
||||
let preview_len = chunk_preview.len().min(200);
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to parse non-streaming ResponsesAPIResponse: {}. Decompressed preview (first {} bytes): {}",
|
||||
self.request_id,
|
||||
e,
|
||||
preview_len,
|
||||
&chunk_preview[..preview_len]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: StreamProcessor> StreamProcessor for ResponsesStateProcessor<P> {
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||
// Buffer/parse chunk for response extraction
|
||||
self.try_parse_response_chunk(&chunk);
|
||||
|
||||
// Forward to inner processor
|
||||
self.inner.process_chunk(chunk)
|
||||
}
|
||||
|
||||
fn on_first_bytes(&mut self) {
|
||||
self.inner.on_first_bytes();
|
||||
}
|
||||
|
||||
fn on_complete(&mut self) {
|
||||
// For non-streaming, decompress and parse buffered response
|
||||
self.try_parse_buffered_response();
|
||||
|
||||
// First, let the inner processor complete
|
||||
self.inner.on_complete();
|
||||
|
||||
// Skip storage for OpenAI upstream
|
||||
if self.is_openai_upstream {
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Skipping state storage for OpenAI upstream provider",
|
||||
self.request_id
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Store state if we captured response_id and output
|
||||
if let (Some(response_id), Some(output_items)) = (&self.response_id, &self.output_items) {
|
||||
// Convert output items to input items for next request
|
||||
let output_as_inputs = outputs_to_inputs(output_items);
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Converting outputs to inputs: output_items_count={}, converted_input_items_count={}",
|
||||
self.request_id, output_items.len(), output_as_inputs.len()
|
||||
);
|
||||
|
||||
// Combine original input + output as new input history
|
||||
let mut combined_input = self.original_input.clone();
|
||||
combined_input.extend(output_as_inputs);
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Storing state: original_input_count={}, combined_input_count={}, combined_json={}",
|
||||
self.request_id,
|
||||
self.original_input.len(),
|
||||
combined_input.len(),
|
||||
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
|
||||
);
|
||||
|
||||
let state = OpenAIConversationState {
|
||||
response_id: response_id.clone(),
|
||||
input_items: combined_input,
|
||||
created_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs() as i64,
|
||||
model: self.model.clone(),
|
||||
provider: self.provider.clone(),
|
||||
};
|
||||
|
||||
// Store asynchronously (fire and forget with logging)
|
||||
let storage = self.storage.clone();
|
||||
let response_id_clone = response_id.clone();
|
||||
let request_id = self.request_id.clone();
|
||||
let items_count = state.input_items.len();
|
||||
tokio::spawn(async move {
|
||||
match storage.put(state).await {
|
||||
Ok(()) => {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully stored conversation state for response_id: {}, items_count={}",
|
||||
request_id,
|
||||
response_id_clone,
|
||||
items_count
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to store conversation state for response_id {}: {}",
|
||||
request_id,
|
||||
response_id_clone,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | No response_id captured from upstream response - cannot store conversation state. response_id present: {}, output present: {}",
|
||||
self.request_id,
|
||||
self.response_id.is_some(),
|
||||
self.output_items.is_some()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn on_error(&mut self, error: &str) {
|
||||
self.inner.on_error(error);
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue