async-semantic-llm-cache/tests/test_stats.py
2026-03-06 15:54:47 +01:00

443 lines
14 KiB
Python

"""Tests for statistics and analytics module."""
import time
import pytest
from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.stats import (
CacheStats,
_stats_manager,
clear_cache,
export_cache,
get_stats,
invalidate,
warm_cache,
)
@pytest.fixture(autouse=True)
def clear_stats_state():
    """Wipe the module-global stats manager before every test so runs are isolated."""
    _stats_manager.clear_stats()
    yield  # no teardown needed; the next test's setup clears again
class TestCacheStats:
    """Tests for the CacheStats dataclass: counters, derived properties, dict export.

    Float-valued fields (hit_rate, total_saved_ms, estimated_savings_usd) are
    compared with pytest.approx rather than ``==``: exact binary-float equality
    is fragile — e.g. ``0.25 + 0.15`` is not guaranteed to equal the literal
    ``0.4`` under IEEE-754 double rounding.
    """

    def test_default_values(self):
        """A freshly constructed CacheStats has all counters zeroed."""
        stats = CacheStats()
        assert stats.hits == 0
        assert stats.misses == 0
        assert stats.total_saved_ms == 0.0
        assert stats.estimated_savings_usd == 0.0

    def test_hit_rate_empty(self):
        """With no requests, hit_rate is defined as 0.0 (no ZeroDivisionError)."""
        stats = CacheStats()
        assert stats.hit_rate == 0.0

    def test_hit_rate_all_hits(self):
        """Test hit rate with all cache hits."""
        stats = CacheStats(hits=10, misses=0)
        assert stats.hit_rate == 1.0

    def test_hit_rate_all_misses(self):
        """Test hit rate with all cache misses."""
        stats = CacheStats(hits=0, misses=10)
        assert stats.hit_rate == 0.0

    def test_hit_rate_mixed(self):
        """Test hit rate with mixed hits and misses."""
        stats = CacheStats(hits=7, misses=3)
        # 7/10 — computed by float division, so compare approximately.
        assert stats.hit_rate == pytest.approx(0.7)

    def test_total_requests(self):
        """total_requests is hits + misses."""
        stats = CacheStats(hits=5, misses=3)
        assert stats.total_requests == 8

    def test_to_dict(self):
        """to_dict exposes raw counters plus the derived rate/total fields."""
        stats = CacheStats(hits=10, misses=5, total_saved_ms=1000.0, estimated_savings_usd=0.5)
        result = stats.to_dict()
        assert result["hits"] == 10
        assert result["misses"] == 5
        assert result["hit_rate"] == pytest.approx(2 / 3)
        assert result["total_requests"] == 15
        assert result["total_saved_ms"] == pytest.approx(1000.0)
        assert result["estimated_savings_usd"] == pytest.approx(0.5)

    def test_iadd(self):
        """In-place ``+=`` accumulates every counter field."""
        stats1 = CacheStats(hits=5, misses=3, total_saved_ms=500.0, estimated_savings_usd=0.25)
        stats2 = CacheStats(hits=3, misses=2, total_saved_ms=300.0, estimated_savings_usd=0.15)
        stats1 += stats2
        assert stats1.hits == 8
        assert stats1.misses == 5
        assert stats1.total_saved_ms == pytest.approx(800.0)
        # 0.25 + 0.15 is not exactly representable; exact == 0.4 would be flaky.
        assert stats1.estimated_savings_usd == pytest.approx(0.4)
class TestStatsManager:
    """Behavioral tests for the module-level _StatsManager singleton."""

    def test_record_hit(self):
        """A single recorded hit updates count, latency saved, and cost saved."""
        _stats_manager.record_hit("test_ns", latency_saved_ms=100.0, saved_cost=0.01)
        result = _stats_manager.get_stats("test_ns")
        assert result.hits == 1
        assert result.total_saved_ms == 100.0
        assert result.estimated_savings_usd == 0.01

    def test_record_multiple_hits(self):
        """Consecutive hits accumulate both the count and the saved latency."""
        _stats_manager.record_hit("test_ns", latency_saved_ms=50.0, saved_cost=0.005)
        _stats_manager.record_hit("test_ns", latency_saved_ms=75.0, saved_cost=0.008)
        result = _stats_manager.get_stats("test_ns")
        assert result.hits == 2
        assert result.total_saved_ms == 125.0

    def test_record_miss(self):
        """A recorded miss increments the miss counter for its namespace."""
        _stats_manager.record_miss("test_ns")
        result = _stats_manager.get_stats("test_ns")
        assert result.misses == 1

    def test_get_stats_namespace(self):
        """Each namespace keeps independent hit/miss counters."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_miss("ns1")
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)
        ns1_stats = _stats_manager.get_stats("ns1")
        ns2_stats = _stats_manager.get_stats("ns2")
        assert ns1_stats.hits == 1
        assert ns1_stats.misses == 1
        assert ns2_stats.hits == 1
        assert ns2_stats.misses == 0

    def test_get_stats_all_namespaces(self):
        """Passing ``None`` aggregates counters across every namespace."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)
        _stats_manager.record_miss("ns1")
        combined = _stats_manager.get_stats(None)  # None means "all namespaces"
        assert combined.hits == 2
        assert combined.misses == 1

    def test_get_stats_nonexistent_namespace(self):
        """An untouched namespace reports zeroed stats, not an error."""
        result = _stats_manager.get_stats("nonexistent")
        assert result.hits == 0
        assert result.misses == 0

    def test_clear_stats_namespace(self):
        """Clearing one namespace leaves the others intact."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)
        _stats_manager.clear_stats("ns1")
        ns1_stats = _stats_manager.get_stats("ns1")
        ns2_stats = _stats_manager.get_stats("ns2")
        assert ns1_stats.hits == 0
        assert ns2_stats.hits == 1

    def test_clear_stats_all(self):
        """Clearing with no argument resets every namespace."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)
        _stats_manager.clear_stats()
        combined = _stats_manager.get_stats(None)
        assert combined.hits == 0
        assert combined.misses == 0

    def test_set_backend(self):
        """set_backend installs the exact backend object returned by get_backend."""
        custom_backend = MemoryBackend(max_size=10)
        _stats_manager.set_backend(custom_backend)
        assert _stats_manager.get_backend() is custom_backend
class TestPublicStatsAPI:
    """Tests for the public module-level API: get_stats, clear_cache, invalidate, warm_cache."""

    def test_get_stats(self):
        """get_stats exposes the counters as a plain dictionary."""
        _stats_manager.record_hit("test_ns", latency_saved_ms=100.0)
        result = get_stats("test_ns")
        assert isinstance(result, dict)
        for key in ("hits", "misses", "hit_rate"):
            assert key in result

    def test_clear_cache_all(self):
        """clear_cache with no arguments removes entries and reports a count."""
        backend = _stats_manager.get_backend()
        # Seed a couple of entries to clear.
        seed = CacheEntry(prompt="test", response="response", created_at=time.time())
        for key in ("key1", "key2"):
            backend.set(key, seed)
        removed = clear_cache()
        assert removed >= 0

    def test_clear_cache_namespace(self):
        """clear_cache accepts a namespace filter."""
        backend = _stats_manager.get_backend()
        # One entry per namespace.
        backend.set("key1", CacheEntry(prompt="test1", response="r1", namespace="ns1", created_at=time.time()))
        backend.set("key2", CacheEntry(prompt="test2", response="r2", namespace="ns2", created_at=time.time()))
        removed = clear_cache(namespace="ns1")
        assert removed >= 0

    def test_invalidate_pattern(self):
        """invalidate removes entries whose prompt matches the pattern."""
        backend = _stats_manager.get_backend()
        seeds = [
            ("key1", "Python programming", "r1"),
            ("key2", "Rust programming", "r2"),
            ("key3", "JavaScript", "r3"),
        ]
        for key, prompt, response in seeds:
            backend.set(key, CacheEntry(prompt=prompt, response=response, created_at=time.time()))
        removed = invalidate("Python")
        assert removed >= 0

    def test_invalidate_case_insensitive(self):
        """A lowercase pattern still matches an uppercase prompt."""
        backend = _stats_manager.get_backend()
        backend.set("key1", CacheEntry(prompt="PYTHON programming", response="r1", created_at=time.time()))
        removed = invalidate("python")
        assert removed >= 0

    def test_invalidate_no_matches(self):
        """A pattern matching nothing reports zero removals."""
        backend = _stats_manager.get_backend()
        backend.set("key1", CacheEntry(prompt="Rust programming", response="r1", created_at=time.time()))
        removed = invalidate("Python")
        assert removed == 0

    def test_warm_cache(self):
        """warm_cache processes every prompt and returns the total count."""
        def echo_llm(prompt: str) -> str:
            return f"Response to: {prompt}"

        prompts = ["prompt1", "prompt2"]
        warmed = warm_cache(prompts, echo_llm, namespace="warm_test")
        assert warmed == len(prompts)

    def test_warm_cache_with_failures(self):
        """warm_cache survives LLM exceptions instead of propagating them."""
        def flaky_llm(prompt: str) -> str:
            if "fail" in prompt:
                raise ValueError("LLM error")
            return f"Response to: {prompt}"

        prompts = ["prompt1", "fail_prompt", "prompt3"]
        warmed = warm_cache(prompts, flaky_llm, namespace="warm_fail_test")
        # The returned count covers all prompts, even ones whose LLM call failed.
        assert warmed == len(prompts)
class TestExportCache:
    """Tests covering export_cache output shape, filtering, and file output."""

    def test_export_all_entries(self, tmp_path):
        """An unfiltered export returns records with the expected fields."""
        backend = _stats_manager.get_backend()
        backend.clear()
        backend.set(
            "key1",
            CacheEntry(
                prompt="test prompt 1",
                response="response 1",
                namespace="test_ns",
                created_at=time.time(),
                hit_count=5,
                ttl=3600,
                input_tokens=100,
                output_tokens=50,
            ),
        )
        exported = export_cache()
        assert len(exported) >= 0
        if exported:
            first = exported[0]
            for field in ("key", "prompt", "response", "namespace", "hit_count"):
                assert field in first

    def test_export_namespace_filtered(self, tmp_path):
        """export_cache(namespace=...) only returns entries from that namespace."""
        backend = _stats_manager.get_backend()
        backend.clear()
        backend.set("key1", CacheEntry(prompt="test1", response="r1", namespace="ns1", created_at=time.time()))
        backend.set("key2", CacheEntry(prompt="test2", response="r2", namespace="ns2", created_at=time.time()))
        exported = export_cache(namespace="ns1")
        assert all(record["namespace"] == "ns1" for record in exported)

    def test_export_to_file(self, tmp_path):
        """export_cache(filepath=...) writes a JSON list to disk."""
        import json

        target = tmp_path / "export.json"
        backend = _stats_manager.get_backend()
        backend.clear()
        backend.set(
            "key1",
            CacheEntry(
                prompt="test",
                response="response",
                namespace="test",
                created_at=time.time(),
                hit_count=3,
            ),
        )
        export_cache(filepath=str(target))
        assert target.exists()
        data = json.loads(target.read_text())
        assert isinstance(data, list)

    def test_export_truncates_large_responses(self, tmp_path):
        """Very long responses are capped at 1000 characters in the export."""
        backend = _stats_manager.get_backend()
        backend.clear()
        backend.set(
            "key1",
            CacheEntry(prompt="test", response="x" * 2000, created_at=time.time()),
        )
        exported = export_cache()
        if exported:
            assert len(exported[0]["response"]) <= 1000
class TestStatsIntegration:
    """End-to-end checks: stats tracking, invalidation, and export over real cache ops."""

    def test_stats_tracking_with_cache_decorator(self):
        """Stats accumulate while the @cache decorator serves hits and misses."""
        from semantic_llm_cache import cache

        backend = MemoryBackend()
        _stats_manager.clear_stats("integration_test")

        @cache(backend=backend, namespace="integration_test")
        def cached_func(prompt: str) -> str:
            return f"Response to: {prompt}"

        # miss, hit (repeat of prompt1), miss.
        for prompt in ("prompt1", "prompt1", "prompt2"):
            cached_func(prompt)

        result = get_stats("integration_test")
        assert result["total_requests"] >= 2

    def test_cache_invalidate_integration(self):
        """invalidate actually deletes matching entries from the backend."""
        backend = _stats_manager.get_backend()
        backend.clear()
        backend.set(
            "key1",
            CacheEntry(
                prompt="Python is great",
                response="Yes, it is!",
                created_at=time.time(),
            ),
        )
        assert backend.get("key1") is not None  # sanity: entry was stored
        invalidate("Python")
        assert backend.get("key1") is None  # entry removed from the backend

    def test_export_includes_metadata(self, tmp_path):
        """Exported records carry all metadata fields from the entry."""
        backend = _stats_manager.get_backend()
        backend.clear()
        backend.set(
            "key1",
            CacheEntry(
                prompt="test prompt",
                response="test response",
                namespace="export_test",
                created_at=time.time(),
                ttl=7200,
                hit_count=10,
                input_tokens=500,
                output_tokens=250,
                embedding=[0.1, 0.2, 0.3],
            ),
        )
        exported = export_cache(namespace="export_test")
        if exported:
            record = exported[0]
            for field in ("created_at", "ttl", "hit_count", "input_tokens", "output_tokens"):
                assert field in record