"""Tests for core cache decorator and API.""" import time from unittest.mock import MagicMock import pytest from semantic_llm_cache import CacheContext, CachedLLM, cache, get_default_backend, set_default_backend from semantic_llm_cache.backends import MemoryBackend from semantic_llm_cache.config import CacheConfig, CacheEntry from semantic_llm_cache.exceptions import PromptCacheError class TestCacheDecorator: """Tests for @cache decorator.""" def test_exact_match_cache_hit(self, mock_llm_func): """Test exact match caching returns cached result.""" call_count = {"count": 0} @cache() def cached_func(prompt: str) -> str: call_count["count"] += 1 return mock_llm_func(prompt) # First call - cache miss result1 = cached_func("What is Python?") assert result1 == "Python is a programming language." assert call_count["count"] == 1 # Second call - cache hit result2 = cached_func("What is Python?") assert result2 == "Python is a programming language." assert call_count["count"] == 1 # No additional call def test_cache_miss_different_prompt(self, mock_llm_func): """Test different prompts result in cache misses.""" @cache() def cached_func(prompt: str) -> str: return mock_llm_func(prompt) result1 = cached_func("What is Python?") result2 = cached_func("What is Rust?") assert result1 == "Python is a programming language." assert result2 == "Rust is a systems programming language." def test_cache_disabled(self, mock_llm_func): """Test caching can be disabled.""" call_count = {"count": 0} @cache(enabled=False) def cached_func(prompt: str) -> str: call_count["count"] += 1 return mock_llm_func(prompt) cached_func("What is Python?") cached_func("What is Python?") assert call_count["count"] == 2 # Both calls hit the function def test_custom_namespace(self, mock_llm_func): """Test custom namespace isolates cache.""" @cache(namespace="test") def cached_func(prompt: str) -> str: return mock_llm_func(prompt) result = cached_func("What is Python?") assert result == "Python is a programming language." def test_ttl_expiration(self, mock_llm_func): """Test TTL expiration works.""" @cache(ttl=1) # 1 second TTL def cached_func(prompt: str) -> str: return mock_llm_func(prompt) # First call result1 = cached_func("What is Python?") assert result1 == "Python is a programming language." # Immediate second call - should hit cache result2 = cached_func("What is Python?") assert result2 == "Python is a programming language." # Wait for expiration time.sleep(1.1) # Third call - should miss due to TTL # Note: This test may be flaky in slow CI environments result3 = cached_func("What is Python?") assert result3 == "Python is a programming language." 

class TestCacheContext:
    """Tests for CacheContext manager."""

    def test_context_manager(self):
        """Test CacheContext works as context manager."""
        with CacheContext(similarity=0.9, ttl=1800) as ctx:
            assert ctx.config.similarity_threshold == 0.9
            assert ctx.config.ttl == 1800

    def test_context_stats(self):
        """Test context tracks stats."""
        with CacheContext() as ctx:
            stats = ctx.stats
            assert "hits" in stats
            assert "misses" in stats


class TestCachedLLM:
    """Tests for CachedLLM wrapper class."""

    def test_init(self):
        """Test CachedLLM initialization."""
        llm = CachedLLM(provider="openai", model="gpt-4")
        assert llm._provider == "openai"
        assert llm._model == "gpt-4"

    def test_chat_with_llm_func(self):
        """Test chat method with custom LLM function."""
        llm = CachedLLM()

        def mock_llm(prompt: str) -> str:
            return f"Response to: {prompt}"

        result = llm.chat("Hello", llm_func=mock_llm)
        assert result == "Response to: Hello"

    def test_chat_caches_responses(self, mock_llm_func):
        """Test chat caches responses."""
        llm = CachedLLM()
        call_count = {"count": 0}

        def counting_llm(prompt: str) -> str:
            call_count["count"] += 1
            return mock_llm_func(prompt)

        llm.chat("What is Python?", llm_func=counting_llm)
        llm.chat("What is Python?", llm_func=counting_llm)
        # Should cache (depends on embedding, may not with dummy)
        assert call_count["count"] >= 1


class TestBackendManagement:
    """Tests for backend management functions."""

    def test_get_default_backend(self):
        """Test get_default_backend returns a backend."""
        backend = get_default_backend()
        assert backend is not None
        assert isinstance(backend, MemoryBackend)

    def test_set_default_backend(self):
        """Test set_default_backend changes default."""
        custom_backend = MemoryBackend(max_size=10)
        set_default_backend(custom_backend)
        backend = get_default_backend()
        assert backend is custom_backend


class TestCacheEntry:
    """Tests for CacheEntry class."""

    def test_entry_creation(self):
        """Test CacheEntry creation."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            created_at=time.time(),
        )
        assert entry.prompt == "test"
        assert entry.response == "response"

    def test_is_expired_no_ttl(self):
        """Test entry without TTL never expires."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            ttl=None,
            created_at=time.time() - 1000,
        )
        assert not entry.is_expired(time.time())

    def test_is_expired_with_ttl(self):
        """Test entry with TTL expires correctly."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            ttl=1,  # 1 second
            created_at=time.time() - 2,  # 2 seconds ago
        )
        assert entry.is_expired(time.time())

    def test_estimate_cost(self):
        """Test cost estimation."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            input_tokens=100,
            output_tokens=50,
        )
        cost = entry.estimate_cost(0.001, 0.002)
        # 100/1000 * 0.001 + 50/1000 * 0.002 = 0.0001 + 0.0001 = 0.0002
        assert abs(cost - 0.0002) < 1e-6
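
# Taken together, the TestCacheEntry cases pin down the semantics these tests
# rely on (a sketch inferred from the assertions, not the library's documented
# contract):
#
#     is_expired(now)   ~  ttl is not None and (now - created_at) >= ttl
#     estimate_cost(input_rate, output_rate)
#                       =  input_tokens / 1000 * input_rate
#                        + output_tokens / 1000 * output_rate
#
# i.e. rates are per 1,000 tokens: 100/1000 * 0.001 + 50/1000 * 0.002 = 0.0002.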

class TestCacheConfig:
    """Tests for CacheConfig class."""

    def test_default_config(self):
        """Test default configuration values."""
        config = CacheConfig()
        assert config.similarity_threshold == 1.0
        assert config.ttl == 3600
        assert config.namespace == "default"
        assert config.enabled is True

    def test_custom_config(self):
        """Test custom configuration values."""
        config = CacheConfig(
            similarity_threshold=0.9,
            ttl=7200,
            namespace="custom",
            enabled=False,
        )
        assert config.similarity_threshold == 0.9
        assert config.ttl == 7200
        assert config.namespace == "custom"
        assert config.enabled is False

    def test_invalid_similarity(self):
        """Test invalid similarity raises error."""
        with pytest.raises(ValueError, match="similarity_threshold"):
            CacheConfig(similarity_threshold=1.5)

    def test_invalid_ttl(self):
        """Test invalid TTL raises error."""
        with pytest.raises(ValueError, match="ttl"):
            CacheConfig(ttl=-1)

    def test_invalid_max_size(self):
        """Test invalid max_size raises error."""
        with pytest.raises(ValueError, match="max_cache_size"):
            CacheConfig(max_cache_size=0)


class TestCacheDecoratorEdgeCases:
    """Tests for edge cases in cache decorator."""

    def test_cache_with_kwargs_only(self, mock_llm_func):
        """Test caching when function is called with kwargs only."""

        @cache()
        def cached_func(prompt: str) -> str:
            return mock_llm_func(prompt)

        result = cached_func(prompt="What is Python?")
        assert result == "Python is a programming language."

    def test_cache_with_multiple_args(self):
        """Test caching with multiple arguments."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str, temperature: float = 0.7) -> str:
            call_count["count"] += 1
            return f"Response to: {prompt} at {temperature}"

        cached_func("test", 0.5)
        cached_func("test", 0.5)
        # Same prompt hits cache even with different temperature
        # (cache key is based on first arg)
        cached_func("test", 0.9)
        cached_func("different", 0.5)
        # First call + different prompt = 2 calls
        assert call_count["count"] == 2

    def test_cache_with_custom_key_func(self):
        """Test custom key function."""
        call_count = {"count": 0}

        def custom_key(prompt: str, temperature: float = 0.7) -> str:
            return f"{prompt}:{temperature}"

        @cache(key_func=custom_key)
        def cached_func(prompt: str, temperature: float = 0.7) -> str:
            call_count["count"] += 1
            return f"Response to: {prompt} at {temperature}"

        cached_func("test", 0.7)
        cached_func("test", 0.7)
        assert call_count["count"] == 1

    def test_semantic_match_threshold_edge(self, mock_llm_func):
        """Test semantic matching at threshold boundaries."""
        call_count = {"count": 0}

        @cache(similarity=0.5)  # Lower threshold
        def cached_func(prompt: str) -> str:
            call_count["count"] += 1
            return mock_llm_func(prompt)

        # First call
        cached_func("What is Python?")
        # Similar query may hit with lower threshold
        cached_func("What is Python?")  # Exact match always hits
        assert call_count["count"] == 1

    def test_cache_with_none_response(self):
        """Test caching None responses."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str) -> None:
            call_count["count"] += 1
            return None

        result1 = cached_func("test")
        result2 = cached_func("test")
        assert result1 is None
        assert result2 is None
        assert call_count["count"] == 1  # Should cache

    def test_cache_with_empty_string_response(self):
        """Test caching empty string responses."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str) -> str:
            call_count["count"] += 1
            return ""

        result1 = cached_func("test")
        result2 = cached_func("test")
        assert result1 == ""
        assert result2 == ""
        assert call_count["count"] == 1  # Should cache

    def test_cache_with_dict_response(self):
        """Test caching dict responses."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str) -> dict:
            call_count["count"] += 1
            return {"key": "value", "number": 42}

        result1 = cached_func("test")
        cached_func("test")  # Second call to verify caching
        assert result1 == {"key": "value", "number": 42}
        assert call_count["count"] == 1

    def test_cache_with_list_response(self):
        """Test caching list responses."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str) -> list:
            call_count["count"] += 1
            return [1, 2, 3, 4, 5]

        result1 = cached_func("test")
        cached_func("test")  # Second call to verify caching
        assert result1 == [1, 2, 3, 4, 5]
        assert call_count["count"] == 1
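
# test_cache_with_multiple_args and test_cache_with_custom_key_func together
# imply the default cache key is derived from the prompt alone, so parameters
# like temperature must be folded into the key explicitly via key_func. A
# hypothetical key_func that also partitions by model (names illustrative):
#
#     def keyed_by_model(prompt: str, temperature: float = 0.7,
#                        model: str = "gpt-4") -> str:
#         # Distinct models/temperatures get distinct cache slots
#         return f"{model}|{temperature}|{prompt}"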

class TestCacheDecoratorErrorPaths:
    """Tests for error handling in cache decorator."""

    def test_backend_set_raises_error(self):
        """Test that backend.set errors propagate."""
        from semantic_llm_cache.exceptions import CacheBackendError

        backend = MagicMock()
        # Backend set wraps exceptions in CacheBackendError
        backend.set.side_effect = CacheBackendError("Storage error")
        backend.get.return_value = None  # No cached value

        @cache(backend=backend)
        def cached_func(prompt: str) -> str:
            return "response"

        # The CacheBackendError from backend.set should propagate
        with pytest.raises(CacheBackendError, match="Storage error"):
            cached_func("test")

    def test_backend_get_raises_error(self):
        """Test that backend.get errors propagate."""
        from semantic_llm_cache.exceptions import CacheBackendError

        backend = MagicMock()
        # Backend get wraps exceptions in CacheBackendError
        backend.get.side_effect = CacheBackendError("Get error")

        @cache(backend=backend)
        def cached_func(prompt: str) -> str:
            return "response"

        # The CacheBackendError from backend.get should propagate
        with pytest.raises(CacheBackendError, match="Get error"):
            cached_func("test")

    def test_llm_error_still_wrapped(self):
        """Test that LLM errors are still wrapped in PromptCacheError."""

        @cache()
        def failing_func(prompt: str) -> str:
            raise ValueError("LLM API error")

        with pytest.raises(PromptCacheError):
            failing_func("test")


class TestCacheContextAdvanced:
    """Advanced tests for CacheContext."""

    def test_context_with_zero_similarity(self):
        """Test context with zero similarity (accept all)."""
        with CacheContext(similarity=0.0) as ctx:
            assert ctx.config.similarity_threshold == 0.0

    def test_context_with_infinite_ttl(self):
        """Test context with infinite TTL (None)."""
        with CacheContext(ttl=None) as ctx:
            assert ctx.config.ttl is None

    def test_context_disabled(self):
        """Test disabled context."""
        with CacheContext(enabled=False) as ctx:
            assert ctx.config.enabled is False


class TestCachedLLMAdvanced:
    """Advanced tests for CachedLLM."""

    def test_chat_with_kwargs(self):
        """Test chat with additional kwargs passed to LLM."""
        llm = CachedLLM()

        def mock_llm(prompt: str, temperature: float = 0.7, max_tokens: int = 100) -> str:
            return f"Response to: {prompt} (temp={temperature}, tokens={max_tokens})"

        result = llm.chat("Hello", llm_func=mock_llm, temperature=0.5, max_tokens=200)
        assert "temp=0.5" in result
        assert "tokens=200" in result

    def test_cached_llm_with_custom_backend(self):
        """Test CachedLLM with custom backend."""
        custom_backend = MemoryBackend(max_size=5)
        llm = CachedLLM(backend=custom_backend)
        assert llm._backend is custom_backend

    def test_cached_llm_different_namespaces(self):
        """Test CachedLLM with different namespaces."""
        llm1 = CachedLLM(namespace="ns1")
        llm2 = CachedLLM(namespace="ns2")
        assert llm1._config.namespace == "ns1"
        assert llm2._config.namespace == "ns2"
CachedLLM(namespace="ns2") assert llm1._config.namespace == "ns1" assert llm2._config.namespace == "ns2" class TestBackendManagementAdvanced: """Advanced tests for backend management.""" def test_multiple_default_backend_changes(self): """Test changing default backend multiple times.""" backend1 = MemoryBackend(max_size=10) backend2 = MemoryBackend(max_size=20) backend3 = MemoryBackend(max_size=30) set_default_backend(backend1) assert get_default_backend() is backend1 set_default_backend(backend2) assert get_default_backend() is backend2 set_default_backend(backend3) assert get_default_backend() is backend3 def test_backend_persists_stats(self): """Test backend stats persist across get_default_backend calls.""" backend = get_default_backend() # Create an entry from semantic_llm_cache.config import CacheEntry entry = CacheEntry( prompt="test", response="response", created_at=time.time() ) backend.set("key1", entry) # Get stats stats = backend.get_stats() assert stats["size"] == 1 class TestCacheEntryEdgeCases: """Edge case tests for CacheEntry.""" def test_entry_with_zero_tokens(self): """Test entry with zero token counts.""" entry = CacheEntry( prompt="test", response="response", input_tokens=0, output_tokens=0, ) cost = entry.estimate_cost(0.001, 0.002) assert cost == 0.0 def test_entry_with_large_token_count(self): """Test entry with large token counts.""" entry = CacheEntry( prompt="test", response="response", input_tokens=100000, output_tokens=50000, ) cost = entry.estimate_cost(0.001, 0.002) assert cost > 0 def test_entry_with_negative_ttl(self): """Test entry creation handles negative TTL (becomes expired).""" entry = CacheEntry( prompt="test", response="response", ttl=-1, created_at=time.time(), ) # Negative TTL means immediately expired assert entry.is_expired(time.time()) def test_entry_hit_count_initialization(self): """Test entry initializes with zero hit count.""" entry = CacheEntry( prompt="test", response="response", created_at=time.time(), ) assert entry.hit_count == 0 class TestCacheConfigEdgeCases: """Edge case tests for CacheConfig.""" def test_config_boundary_similarity(self): """Test similarity at valid boundaries.""" config1 = CacheConfig(similarity_threshold=0.0) assert config1.similarity_threshold == 0.0 config2 = CacheConfig(similarity_threshold=1.0) assert config2.similarity_threshold == 1.0 def test_config_zero_ttl(self): """Test zero TTL is rejected (validation requires positive).""" # The validation in CacheConfig rejects ttl <= 0 with pytest.raises(ValueError, match="ttl"): CacheConfig(ttl=0) def test_config_very_large_ttl(self): """Test very large TTL.""" config = CacheConfig(ttl=86400 * 365) # 1 year assert config.ttl == 86400 * 365 def test_config_with_special_namespace(self): """Test namespace with special characters.""" config = CacheConfig(namespace="test-ns_123.v1") assert config.namespace == "test-ns_123.v1"