mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
596 lines
25 KiB
Python
596 lines
25 KiB
Python
|
|
"""
|
||
|
|
Unit tests for conversation state management
|
||
|
|
|
||
|
|
Tests the core business logic for managing conversation state,
|
||
|
|
including history tracking, context preservation, and multi-turn
|
||
|
|
reasoning support.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import Mock
|
||
|
|
from datetime import datetime, timedelta
|
||
|
|
import json
|
||
|
|
|
||
|
|
|
||
|
|
class TestConversationStateLogic:
|
||
|
|
"""Test cases for conversation state management business logic"""
|
||
|
|
|
||
|
|
def test_conversation_initialization(self):
|
||
|
|
"""Test initialization of new conversation state"""
|
||
|
|
# Arrange
|
||
|
|
class ConversationState:
|
||
|
|
def __init__(self, conversation_id=None, user_id=None):
|
||
|
|
self.conversation_id = conversation_id or f"conv_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||
|
|
self.user_id = user_id
|
||
|
|
self.created_at = datetime.now()
|
||
|
|
self.updated_at = datetime.now()
|
||
|
|
self.turns = []
|
||
|
|
self.context = {}
|
||
|
|
self.metadata = {}
|
||
|
|
self.is_active = True
|
||
|
|
|
||
|
|
def to_dict(self):
|
||
|
|
return {
|
||
|
|
"conversation_id": self.conversation_id,
|
||
|
|
"user_id": self.user_id,
|
||
|
|
"created_at": self.created_at.isoformat(),
|
||
|
|
"updated_at": self.updated_at.isoformat(),
|
||
|
|
"turns": self.turns,
|
||
|
|
"context": self.context,
|
||
|
|
"metadata": self.metadata,
|
||
|
|
"is_active": self.is_active
|
||
|
|
}
|
||
|
|
|
||
|
|
# Act
|
||
|
|
conv1 = ConversationState(user_id="user123")
|
||
|
|
conv2 = ConversationState(conversation_id="custom_conv_id", user_id="user456")
|
||
|
|
|
||
|
|
# Assert
|
||
|
|
assert conv1.conversation_id.startswith("conv_")
|
||
|
|
assert conv1.user_id == "user123"
|
||
|
|
assert conv1.is_active is True
|
||
|
|
assert len(conv1.turns) == 0
|
||
|
|
assert isinstance(conv1.created_at, datetime)
|
||
|
|
|
||
|
|
assert conv2.conversation_id == "custom_conv_id"
|
||
|
|
assert conv2.user_id == "user456"
|
||
|
|
|
||
|
|
# Test serialization
|
||
|
|
conv_dict = conv1.to_dict()
|
||
|
|
assert "conversation_id" in conv_dict
|
||
|
|
assert "created_at" in conv_dict
|
||
|
|
assert isinstance(conv_dict["turns"], list)
|
||
|
|
|
||
|
|
def test_turn_management(self):
|
||
|
|
"""Test adding and managing conversation turns"""
|
||
|
|
# Arrange
|
||
|
|
class ConversationState:
|
||
|
|
def __init__(self, conversation_id=None, user_id=None):
|
||
|
|
self.conversation_id = conversation_id or f"conv_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||
|
|
self.user_id = user_id
|
||
|
|
self.created_at = datetime.now()
|
||
|
|
self.updated_at = datetime.now()
|
||
|
|
self.turns = []
|
||
|
|
self.context = {}
|
||
|
|
self.metadata = {}
|
||
|
|
self.is_active = True
|
||
|
|
|
||
|
|
def to_dict(self):
|
||
|
|
return {
|
||
|
|
"conversation_id": self.conversation_id,
|
||
|
|
"user_id": self.user_id,
|
||
|
|
"created_at": self.created_at.isoformat(),
|
||
|
|
"updated_at": self.updated_at.isoformat(),
|
||
|
|
"turns": self.turns,
|
||
|
|
"context": self.context,
|
||
|
|
"metadata": self.metadata,
|
||
|
|
"is_active": self.is_active
|
||
|
|
}
|
||
|
|
|
||
|
|
class ConversationTurn:
|
||
|
|
def __init__(self, role, content, timestamp=None, metadata=None):
|
||
|
|
self.role = role # "user" or "assistant"
|
||
|
|
self.content = content
|
||
|
|
self.timestamp = timestamp or datetime.now()
|
||
|
|
self.metadata = metadata or {}
|
||
|
|
|
||
|
|
def to_dict(self):
|
||
|
|
return {
|
||
|
|
"role": self.role,
|
||
|
|
"content": self.content,
|
||
|
|
"timestamp": self.timestamp.isoformat(),
|
||
|
|
"metadata": self.metadata
|
||
|
|
}
|
||
|
|
|
||
|
|
class ConversationManager:
|
||
|
|
def __init__(self):
|
||
|
|
self.conversations = {}
|
||
|
|
|
||
|
|
def add_turn(self, conversation_id, role, content, metadata=None):
|
||
|
|
if conversation_id not in self.conversations:
|
||
|
|
return False, "Conversation not found"
|
||
|
|
|
||
|
|
turn = ConversationTurn(role, content, metadata=metadata)
|
||
|
|
self.conversations[conversation_id].turns.append(turn)
|
||
|
|
self.conversations[conversation_id].updated_at = datetime.now()
|
||
|
|
|
||
|
|
return True, turn
|
||
|
|
|
||
|
|
def get_recent_turns(self, conversation_id, limit=10):
|
||
|
|
if conversation_id not in self.conversations:
|
||
|
|
return []
|
||
|
|
|
||
|
|
turns = self.conversations[conversation_id].turns
|
||
|
|
return turns[-limit:] if len(turns) > limit else turns
|
||
|
|
|
||
|
|
def get_turn_count(self, conversation_id):
|
||
|
|
if conversation_id not in self.conversations:
|
||
|
|
return 0
|
||
|
|
return len(self.conversations[conversation_id].turns)
|
||
|
|
|
||
|
|
# Act
|
||
|
|
manager = ConversationManager()
|
||
|
|
conv_id = "test_conv"
|
||
|
|
|
||
|
|
# Create conversation - use the local ConversationState class
|
||
|
|
conv_state = ConversationState(conv_id)
|
||
|
|
manager.conversations[conv_id] = conv_state
|
||
|
|
|
||
|
|
# Add turns
|
||
|
|
success1, turn1 = manager.add_turn(conv_id, "user", "Hello, what is 2+2?")
|
||
|
|
success2, turn2 = manager.add_turn(conv_id, "assistant", "2+2 equals 4.")
|
||
|
|
success3, turn3 = manager.add_turn(conv_id, "user", "What about 3+3?")
|
||
|
|
|
||
|
|
# Assert
|
||
|
|
assert success1 is True
|
||
|
|
assert turn1.role == "user"
|
||
|
|
assert turn1.content == "Hello, what is 2+2?"
|
||
|
|
|
||
|
|
assert manager.get_turn_count(conv_id) == 3
|
||
|
|
|
||
|
|
recent_turns = manager.get_recent_turns(conv_id, limit=2)
|
||
|
|
assert len(recent_turns) == 2
|
||
|
|
assert recent_turns[0].role == "assistant"
|
||
|
|
assert recent_turns[1].role == "user"
|
||
|
|
|
||
|
|
def test_context_preservation(self):
|
||
|
|
"""Test preservation and retrieval of conversation context"""
|
||
|
|
# Arrange
|
||
|
|
class ContextManager:
|
||
|
|
def __init__(self):
|
||
|
|
self.contexts = {}
|
||
|
|
|
||
|
|
def set_context(self, conversation_id, key, value, ttl_minutes=None):
|
||
|
|
"""Set context value with optional TTL"""
|
||
|
|
if conversation_id not in self.contexts:
|
||
|
|
self.contexts[conversation_id] = {}
|
||
|
|
|
||
|
|
context_entry = {
|
||
|
|
"value": value,
|
||
|
|
"created_at": datetime.now(),
|
||
|
|
"ttl_minutes": ttl_minutes
|
||
|
|
}
|
||
|
|
|
||
|
|
self.contexts[conversation_id][key] = context_entry
|
||
|
|
|
||
|
|
def get_context(self, conversation_id, key, default=None):
|
||
|
|
"""Get context value, respecting TTL"""
|
||
|
|
if conversation_id not in self.contexts:
|
||
|
|
return default
|
||
|
|
|
||
|
|
if key not in self.contexts[conversation_id]:
|
||
|
|
return default
|
||
|
|
|
||
|
|
entry = self.contexts[conversation_id][key]
|
||
|
|
|
||
|
|
# Check TTL
|
||
|
|
if entry["ttl_minutes"]:
|
||
|
|
age = datetime.now() - entry["created_at"]
|
||
|
|
if age > timedelta(minutes=entry["ttl_minutes"]):
|
||
|
|
# Expired
|
||
|
|
del self.contexts[conversation_id][key]
|
||
|
|
return default
|
||
|
|
|
||
|
|
return entry["value"]
|
||
|
|
|
||
|
|
def update_context(self, conversation_id, updates):
|
||
|
|
"""Update multiple context values"""
|
||
|
|
for key, value in updates.items():
|
||
|
|
self.set_context(conversation_id, key, value)
|
||
|
|
|
||
|
|
def clear_context(self, conversation_id, keys=None):
|
||
|
|
"""Clear specific keys or entire context"""
|
||
|
|
if conversation_id not in self.contexts:
|
||
|
|
return
|
||
|
|
|
||
|
|
if keys is None:
|
||
|
|
# Clear all context
|
||
|
|
self.contexts[conversation_id] = {}
|
||
|
|
else:
|
||
|
|
# Clear specific keys
|
||
|
|
for key in keys:
|
||
|
|
self.contexts[conversation_id].pop(key, None)
|
||
|
|
|
||
|
|
def get_all_context(self, conversation_id):
|
||
|
|
"""Get all context for conversation"""
|
||
|
|
if conversation_id not in self.contexts:
|
||
|
|
return {}
|
||
|
|
|
||
|
|
# Filter out expired entries
|
||
|
|
valid_context = {}
|
||
|
|
for key, entry in self.contexts[conversation_id].items():
|
||
|
|
if entry["ttl_minutes"]:
|
||
|
|
age = datetime.now() - entry["created_at"]
|
||
|
|
if age <= timedelta(minutes=entry["ttl_minutes"]):
|
||
|
|
valid_context[key] = entry["value"]
|
||
|
|
else:
|
||
|
|
valid_context[key] = entry["value"]
|
||
|
|
|
||
|
|
return valid_context
|
||
|
|
|
||
|
|
# Act
|
||
|
|
context_manager = ContextManager()
|
||
|
|
conv_id = "test_conv"
|
||
|
|
|
||
|
|
# Set various context values
|
||
|
|
context_manager.set_context(conv_id, "user_name", "Alice")
|
||
|
|
context_manager.set_context(conv_id, "topic", "mathematics")
|
||
|
|
context_manager.set_context(conv_id, "temp_calculation", "2+2=4", ttl_minutes=1)
|
||
|
|
|
||
|
|
# Assert
|
||
|
|
assert context_manager.get_context(conv_id, "user_name") == "Alice"
|
||
|
|
assert context_manager.get_context(conv_id, "topic") == "mathematics"
|
||
|
|
assert context_manager.get_context(conv_id, "temp_calculation") == "2+2=4"
|
||
|
|
assert context_manager.get_context(conv_id, "nonexistent", "default") == "default"
|
||
|
|
|
||
|
|
# Test bulk updates
|
||
|
|
context_manager.update_context(conv_id, {
|
||
|
|
"calculation_count": 1,
|
||
|
|
"last_operation": "addition"
|
||
|
|
})
|
||
|
|
|
||
|
|
all_context = context_manager.get_all_context(conv_id)
|
||
|
|
assert "calculation_count" in all_context
|
||
|
|
assert "last_operation" in all_context
|
||
|
|
assert len(all_context) == 5
|
||
|
|
|
||
|
|
# Test clearing specific keys
|
||
|
|
context_manager.clear_context(conv_id, ["temp_calculation"])
|
||
|
|
assert context_manager.get_context(conv_id, "temp_calculation") is None
|
||
|
|
assert context_manager.get_context(conv_id, "user_name") == "Alice"
|
||
|
|
|
||
|
|
def test_multi_turn_reasoning_state(self):
|
||
|
|
"""Test state management for multi-turn reasoning"""
|
||
|
|
# Arrange
|
||
|
|
class ReasoningStateManager:
|
||
|
|
def __init__(self):
|
||
|
|
self.reasoning_states = {}
|
||
|
|
|
||
|
|
def start_reasoning_session(self, conversation_id, question, reasoning_type="sequential"):
|
||
|
|
"""Start a new reasoning session"""
|
||
|
|
session_id = f"{conversation_id}_reasoning_{datetime.now().strftime('%H%M%S')}"
|
||
|
|
|
||
|
|
self.reasoning_states[session_id] = {
|
||
|
|
"conversation_id": conversation_id,
|
||
|
|
"original_question": question,
|
||
|
|
"reasoning_type": reasoning_type,
|
||
|
|
"status": "active",
|
||
|
|
"steps": [],
|
||
|
|
"intermediate_results": {},
|
||
|
|
"final_answer": None,
|
||
|
|
"created_at": datetime.now(),
|
||
|
|
"updated_at": datetime.now()
|
||
|
|
}
|
||
|
|
|
||
|
|
return session_id
|
||
|
|
|
||
|
|
def add_reasoning_step(self, session_id, step_type, content, tool_result=None):
|
||
|
|
"""Add a step to reasoning session"""
|
||
|
|
if session_id not in self.reasoning_states:
|
||
|
|
return False
|
||
|
|
|
||
|
|
step = {
|
||
|
|
"step_number": len(self.reasoning_states[session_id]["steps"]) + 1,
|
||
|
|
"step_type": step_type, # "think", "act", "observe"
|
||
|
|
"content": content,
|
||
|
|
"tool_result": tool_result,
|
||
|
|
"timestamp": datetime.now()
|
||
|
|
}
|
||
|
|
|
||
|
|
self.reasoning_states[session_id]["steps"].append(step)
|
||
|
|
self.reasoning_states[session_id]["updated_at"] = datetime.now()
|
||
|
|
|
||
|
|
return True
|
||
|
|
|
||
|
|
def set_intermediate_result(self, session_id, key, value):
|
||
|
|
"""Store intermediate result for later use"""
|
||
|
|
if session_id not in self.reasoning_states:
|
||
|
|
return False
|
||
|
|
|
||
|
|
self.reasoning_states[session_id]["intermediate_results"][key] = value
|
||
|
|
return True
|
||
|
|
|
||
|
|
def get_intermediate_result(self, session_id, key):
|
||
|
|
"""Retrieve intermediate result"""
|
||
|
|
if session_id not in self.reasoning_states:
|
||
|
|
return None
|
||
|
|
|
||
|
|
return self.reasoning_states[session_id]["intermediate_results"].get(key)
|
||
|
|
|
||
|
|
def complete_reasoning_session(self, session_id, final_answer):
|
||
|
|
"""Mark reasoning session as complete"""
|
||
|
|
if session_id not in self.reasoning_states:
|
||
|
|
return False
|
||
|
|
|
||
|
|
self.reasoning_states[session_id]["final_answer"] = final_answer
|
||
|
|
self.reasoning_states[session_id]["status"] = "completed"
|
||
|
|
self.reasoning_states[session_id]["updated_at"] = datetime.now()
|
||
|
|
|
||
|
|
return True
|
||
|
|
|
||
|
|
def get_reasoning_summary(self, session_id):
|
||
|
|
"""Get summary of reasoning session"""
|
||
|
|
if session_id not in self.reasoning_states:
|
||
|
|
return None
|
||
|
|
|
||
|
|
state = self.reasoning_states[session_id]
|
||
|
|
return {
|
||
|
|
"original_question": state["original_question"],
|
||
|
|
"step_count": len(state["steps"]),
|
||
|
|
"status": state["status"],
|
||
|
|
"final_answer": state["final_answer"],
|
||
|
|
"reasoning_chain": [step["content"] for step in state["steps"] if step["step_type"] == "think"]
|
||
|
|
}
|
||
|
|
|
||
|
|
# Act
|
||
|
|
reasoning_manager = ReasoningStateManager()
|
||
|
|
conv_id = "test_conv"
|
||
|
|
|
||
|
|
# Start reasoning session
|
||
|
|
session_id = reasoning_manager.start_reasoning_session(
|
||
|
|
conv_id,
|
||
|
|
"What is the population of the capital of France?"
|
||
|
|
)
|
||
|
|
|
||
|
|
# Add reasoning steps
|
||
|
|
reasoning_manager.add_reasoning_step(session_id, "think", "I need to find the capital first")
|
||
|
|
reasoning_manager.add_reasoning_step(session_id, "act", "search for capital of France", "Paris")
|
||
|
|
reasoning_manager.set_intermediate_result(session_id, "capital", "Paris")
|
||
|
|
|
||
|
|
reasoning_manager.add_reasoning_step(session_id, "observe", "Found that Paris is the capital")
|
||
|
|
reasoning_manager.add_reasoning_step(session_id, "think", "Now I need to find Paris population")
|
||
|
|
reasoning_manager.add_reasoning_step(session_id, "act", "search for Paris population", "2.1 million")
|
||
|
|
|
||
|
|
reasoning_manager.complete_reasoning_session(session_id, "The population of Paris is approximately 2.1 million")
|
||
|
|
|
||
|
|
# Assert
|
||
|
|
assert session_id.startswith(f"{conv_id}_reasoning_")
|
||
|
|
|
||
|
|
capital = reasoning_manager.get_intermediate_result(session_id, "capital")
|
||
|
|
assert capital == "Paris"
|
||
|
|
|
||
|
|
summary = reasoning_manager.get_reasoning_summary(session_id)
|
||
|
|
assert summary["original_question"] == "What is the population of the capital of France?"
|
||
|
|
assert summary["step_count"] == 5
|
||
|
|
assert summary["status"] == "completed"
|
||
|
|
assert "2.1 million" in summary["final_answer"]
|
||
|
|
assert len(summary["reasoning_chain"]) == 2 # Two "think" steps
|
||
|
|
|
||
|
|
def test_conversation_memory_management(self):
|
||
|
|
"""Test memory management for long conversations"""
|
||
|
|
# Arrange
|
||
|
|
class ConversationMemoryManager:
|
||
|
|
def __init__(self, max_turns=100, max_context_age_hours=24):
|
||
|
|
self.max_turns = max_turns
|
||
|
|
self.max_context_age_hours = max_context_age_hours
|
||
|
|
self.conversations = {}
|
||
|
|
|
||
|
|
def add_conversation_turn(self, conversation_id, role, content, metadata=None):
|
||
|
|
"""Add turn with automatic memory management"""
|
||
|
|
if conversation_id not in self.conversations:
|
||
|
|
self.conversations[conversation_id] = {
|
||
|
|
"turns": [],
|
||
|
|
"context": {},
|
||
|
|
"created_at": datetime.now()
|
||
|
|
}
|
||
|
|
|
||
|
|
turn = {
|
||
|
|
"role": role,
|
||
|
|
"content": content,
|
||
|
|
"timestamp": datetime.now(),
|
||
|
|
"metadata": metadata or {}
|
||
|
|
}
|
||
|
|
|
||
|
|
self.conversations[conversation_id]["turns"].append(turn)
|
||
|
|
|
||
|
|
# Apply memory management
|
||
|
|
self._manage_memory(conversation_id)
|
||
|
|
|
||
|
|
def _manage_memory(self, conversation_id):
|
||
|
|
"""Apply memory management policies"""
|
||
|
|
conv = self.conversations[conversation_id]
|
||
|
|
|
||
|
|
# Limit turn count
|
||
|
|
if len(conv["turns"]) > self.max_turns:
|
||
|
|
# Keep recent turns and important summary turns
|
||
|
|
turns_to_keep = self.max_turns // 2
|
||
|
|
important_turns = self._identify_important_turns(conv["turns"])
|
||
|
|
recent_turns = conv["turns"][-turns_to_keep:]
|
||
|
|
|
||
|
|
# Combine important and recent turns, avoiding duplicates
|
||
|
|
kept_turns = []
|
||
|
|
seen_indices = set()
|
||
|
|
|
||
|
|
# Add important turns first
|
||
|
|
for turn_index, turn in important_turns:
|
||
|
|
if turn_index not in seen_indices:
|
||
|
|
kept_turns.append(turn)
|
||
|
|
seen_indices.add(turn_index)
|
||
|
|
|
||
|
|
# Add recent turns
|
||
|
|
for i, turn in enumerate(recent_turns):
|
||
|
|
original_index = len(conv["turns"]) - len(recent_turns) + i
|
||
|
|
if original_index not in seen_indices:
|
||
|
|
kept_turns.append(turn)
|
||
|
|
|
||
|
|
conv["turns"] = kept_turns[-self.max_turns:] # Final limit
|
||
|
|
|
||
|
|
# Clean old context
|
||
|
|
self._clean_old_context(conversation_id)
|
||
|
|
|
||
|
|
def _identify_important_turns(self, turns):
|
||
|
|
"""Identify important turns to preserve"""
|
||
|
|
important = []
|
||
|
|
|
||
|
|
for i, turn in enumerate(turns):
|
||
|
|
# Keep turns with high information content
|
||
|
|
if (len(turn["content"]) > 100 or
|
||
|
|
any(keyword in turn["content"].lower() for keyword in ["calculate", "result", "answer", "conclusion"])):
|
||
|
|
important.append((i, turn))
|
||
|
|
|
||
|
|
return important[:10] # Limit important turns
|
||
|
|
|
||
|
|
def _clean_old_context(self, conversation_id):
|
||
|
|
"""Remove old context entries"""
|
||
|
|
if conversation_id not in self.conversations:
|
||
|
|
return
|
||
|
|
|
||
|
|
cutoff_time = datetime.now() - timedelta(hours=self.max_context_age_hours)
|
||
|
|
context = self.conversations[conversation_id]["context"]
|
||
|
|
|
||
|
|
keys_to_remove = []
|
||
|
|
for key, entry in context.items():
|
||
|
|
if isinstance(entry, dict) and "timestamp" in entry:
|
||
|
|
if entry["timestamp"] < cutoff_time:
|
||
|
|
keys_to_remove.append(key)
|
||
|
|
|
||
|
|
for key in keys_to_remove:
|
||
|
|
del context[key]
|
||
|
|
|
||
|
|
def get_conversation_summary(self, conversation_id):
|
||
|
|
"""Get summary of conversation state"""
|
||
|
|
if conversation_id not in self.conversations:
|
||
|
|
return None
|
||
|
|
|
||
|
|
conv = self.conversations[conversation_id]
|
||
|
|
return {
|
||
|
|
"turn_count": len(conv["turns"]),
|
||
|
|
"context_keys": list(conv["context"].keys()),
|
||
|
|
"age_hours": (datetime.now() - conv["created_at"]).total_seconds() / 3600,
|
||
|
|
"last_activity": conv["turns"][-1]["timestamp"] if conv["turns"] else None
|
||
|
|
}
|
||
|
|
|
||
|
|
# Act
|
||
|
|
memory_manager = ConversationMemoryManager(max_turns=5, max_context_age_hours=1)
|
||
|
|
conv_id = "test_memory_conv"
|
||
|
|
|
||
|
|
# Add many turns to test memory management
|
||
|
|
for i in range(10):
|
||
|
|
memory_manager.add_conversation_turn(
|
||
|
|
conv_id,
|
||
|
|
"user" if i % 2 == 0 else "assistant",
|
||
|
|
f"Turn {i}: {'Important calculation result' if i == 5 else 'Regular content'}"
|
||
|
|
)
|
||
|
|
|
||
|
|
# Assert
|
||
|
|
summary = memory_manager.get_conversation_summary(conv_id)
|
||
|
|
assert summary["turn_count"] <= 5 # Should be limited
|
||
|
|
|
||
|
|
# Check that important turns are preserved
|
||
|
|
turns = memory_manager.conversations[conv_id]["turns"]
|
||
|
|
important_preserved = any("Important calculation" in turn["content"] for turn in turns)
|
||
|
|
assert important_preserved, "Important turns should be preserved"
|
||
|
|
|
||
|
|
def test_conversation_state_persistence(self):
|
||
|
|
"""Test serialization and deserialization of conversation state"""
|
||
|
|
# Arrange
|
||
|
|
class ConversationStatePersistence:
|
||
|
|
def __init__(self):
|
||
|
|
pass
|
||
|
|
|
||
|
|
def serialize_conversation(self, conversation_state):
|
||
|
|
"""Serialize conversation state to JSON-compatible format"""
|
||
|
|
def datetime_serializer(obj):
|
||
|
|
if isinstance(obj, datetime):
|
||
|
|
return obj.isoformat()
|
||
|
|
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
||
|
|
|
||
|
|
return json.dumps(conversation_state, default=datetime_serializer, indent=2)
|
||
|
|
|
||
|
|
def deserialize_conversation(self, serialized_data):
|
||
|
|
"""Deserialize conversation state from JSON"""
|
||
|
|
def datetime_deserializer(data):
|
||
|
|
"""Convert ISO datetime strings back to datetime objects"""
|
||
|
|
if isinstance(data, dict):
|
||
|
|
for key, value in data.items():
|
||
|
|
if isinstance(value, str) and self._is_iso_datetime(value):
|
||
|
|
data[key] = datetime.fromisoformat(value)
|
||
|
|
elif isinstance(value, (dict, list)):
|
||
|
|
data[key] = datetime_deserializer(value)
|
||
|
|
elif isinstance(data, list):
|
||
|
|
for i, item in enumerate(data):
|
||
|
|
data[i] = datetime_deserializer(item)
|
||
|
|
|
||
|
|
return data
|
||
|
|
|
||
|
|
parsed_data = json.loads(serialized_data)
|
||
|
|
return datetime_deserializer(parsed_data)
|
||
|
|
|
||
|
|
def _is_iso_datetime(self, value):
|
||
|
|
"""Check if string is ISO datetime format"""
|
||
|
|
try:
|
||
|
|
datetime.fromisoformat(value.replace('Z', '+00:00'))
|
||
|
|
return True
|
||
|
|
except (ValueError, AttributeError):
|
||
|
|
return False
|
||
|
|
|
||
|
|
# Create sample conversation state
|
||
|
|
conversation_state = {
|
||
|
|
"conversation_id": "test_conv_123",
|
||
|
|
"user_id": "user456",
|
||
|
|
"created_at": datetime.now(),
|
||
|
|
"updated_at": datetime.now(),
|
||
|
|
"turns": [
|
||
|
|
{
|
||
|
|
"role": "user",
|
||
|
|
"content": "Hello",
|
||
|
|
"timestamp": datetime.now(),
|
||
|
|
"metadata": {}
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"role": "assistant",
|
||
|
|
"content": "Hi there!",
|
||
|
|
"timestamp": datetime.now(),
|
||
|
|
"metadata": {"confidence": 0.9}
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"context": {
|
||
|
|
"user_preference": "detailed_answers",
|
||
|
|
"topic": "general"
|
||
|
|
},
|
||
|
|
"metadata": {
|
||
|
|
"platform": "web",
|
||
|
|
"session_start": datetime.now()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
# Act
|
||
|
|
persistence = ConversationStatePersistence()
|
||
|
|
|
||
|
|
# Serialize
|
||
|
|
serialized = persistence.serialize_conversation(conversation_state)
|
||
|
|
assert isinstance(serialized, str)
|
||
|
|
assert "test_conv_123" in serialized
|
||
|
|
|
||
|
|
# Deserialize
|
||
|
|
deserialized = persistence.deserialize_conversation(serialized)
|
||
|
|
|
||
|
|
# Assert
|
||
|
|
assert deserialized["conversation_id"] == "test_conv_123"
|
||
|
|
assert deserialized["user_id"] == "user456"
|
||
|
|
assert isinstance(deserialized["created_at"], datetime)
|
||
|
|
assert len(deserialized["turns"]) == 2
|
||
|
|
assert deserialized["turns"][0]["role"] == "user"
|
||
|
|
assert isinstance(deserialized["turns"][0]["timestamp"], datetime)
|
||
|
|
assert deserialized["context"]["topic"] == "general"
|
||
|
|
assert deserialized["metadata"]["platform"] == "web"
|