"""Tests for statistics and analytics module."""

import json
import time

import pytest

from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.stats import (
    CacheStats,
    _stats_manager,
    clear_cache,
    export_cache,
    get_stats,
    invalidate,
    warm_cache,
)


@pytest.fixture(autouse=True)
def clear_stats_state():
    """Clear stats state before each test."""
    _stats_manager.clear_stats()
    yield


class TestCacheStats:
    """Tests for the CacheStats dataclass."""

    def test_default_values(self):
        """Test CacheStats initializes with defaults."""
        stats = CacheStats()
        assert stats.hits == 0
        assert stats.misses == 0
        assert stats.total_saved_ms == 0.0
        assert stats.estimated_savings_usd == 0.0

    def test_hit_rate_empty(self):
        """Test hit rate with no requests."""
        stats = CacheStats()
        assert stats.hit_rate == 0.0

    def test_hit_rate_all_hits(self):
        """Test hit rate with all cache hits."""
        stats = CacheStats(hits=10, misses=0)
        assert stats.hit_rate == 1.0

    def test_hit_rate_all_misses(self):
        """Test hit rate with all cache misses."""
        stats = CacheStats(hits=0, misses=10)
        assert stats.hit_rate == 0.0

    def test_hit_rate_mixed(self):
        """Test hit rate with mixed hits and misses."""
        stats = CacheStats(hits=7, misses=3)
        assert stats.hit_rate == 0.7

    def test_total_requests(self):
        """Test total requests calculation."""
        stats = CacheStats(hits=5, misses=3)
        assert stats.total_requests == 8

    def test_to_dict(self):
        """Test converting stats to a dictionary."""
        stats = CacheStats(hits=10, misses=5, total_saved_ms=1000.0, estimated_savings_usd=0.5)
        result = stats.to_dict()

        assert result["hits"] == 10
        assert result["misses"] == 5
        assert result["hit_rate"] == pytest.approx(2 / 3)
        assert result["total_requests"] == 15
        assert result["total_saved_ms"] == 1000.0
        assert result["estimated_savings_usd"] == 0.5

    def test_iadd(self):
        """Test in-place addition of stats."""
        stats1 = CacheStats(hits=5, misses=3, total_saved_ms=500.0, estimated_savings_usd=0.25)
        stats2 = CacheStats(hits=3, misses=2, total_saved_ms=300.0, estimated_savings_usd=0.15)

        stats1 += stats2

        assert stats1.hits == 8
        assert stats1.misses == 5
        assert stats1.total_saved_ms == 800.0
        # approx avoids brittle exact equality on summed floats
        assert stats1.estimated_savings_usd == pytest.approx(0.4)
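
    # Added check: to_dict on a fresh CacheStats should mirror the zero
    # defaults asserted above. Assumes to_dict derives hit_rate and
    # total_requests from the same properties exercised earlier in this class.
    def test_to_dict_empty(self):
        """Test to_dict output for a freshly created CacheStats."""
        result = CacheStats().to_dict()
        assert result["hits"] == 0
        assert result["misses"] == 0
        assert result["hit_rate"] == 0.0
        assert result["total_requests"] == 0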


class TestStatsManager:
    """Tests for _StatsManager."""

    def test_record_hit(self):
        """Test recording a cache hit."""
        _stats_manager.record_hit("test_ns", latency_saved_ms=100.0, saved_cost=0.01)

        stats = _stats_manager.get_stats("test_ns")
        assert stats.hits == 1
        assert stats.total_saved_ms == 100.0
        assert stats.estimated_savings_usd == 0.01

    def test_record_multiple_hits(self):
        """Test recording multiple hits."""
        _stats_manager.record_hit("test_ns", latency_saved_ms=50.0, saved_cost=0.005)
        _stats_manager.record_hit("test_ns", latency_saved_ms=75.0, saved_cost=0.008)

        stats = _stats_manager.get_stats("test_ns")
        assert stats.hits == 2
        assert stats.total_saved_ms == 125.0

    def test_record_miss(self):
        """Test recording a cache miss."""
        _stats_manager.record_miss("test_ns")

        stats = _stats_manager.get_stats("test_ns")
        assert stats.misses == 1

    def test_get_stats_namespace(self):
        """Test getting stats for a specific namespace."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_miss("ns1")
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)

        stats1 = _stats_manager.get_stats("ns1")
        stats2 = _stats_manager.get_stats("ns2")

        assert stats1.hits == 1
        assert stats1.misses == 1
        assert stats2.hits == 1
        assert stats2.misses == 0

    def test_get_stats_all_namespaces(self):
        """Test getting aggregated stats for all namespaces."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)
        _stats_manager.record_miss("ns1")

        stats = _stats_manager.get_stats(None)  # All namespaces
        assert stats.hits == 2
        assert stats.misses == 1

    def test_get_stats_nonexistent_namespace(self):
        """Test getting stats for a namespace with no activity."""
        stats = _stats_manager.get_stats("nonexistent")
        assert stats.hits == 0
        assert stats.misses == 0

    def test_clear_stats_namespace(self):
        """Test clearing stats for a specific namespace."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)

        _stats_manager.clear_stats("ns1")

        stats1 = _stats_manager.get_stats("ns1")
        stats2 = _stats_manager.get_stats("ns2")

        assert stats1.hits == 0
        assert stats2.hits == 1

    def test_clear_stats_all(self):
        """Test clearing all stats."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)

        _stats_manager.clear_stats()

        stats = _stats_manager.get_stats(None)
        assert stats.hits == 0
        assert stats.misses == 0

    def test_set_backend(self):
        """Test setting the default backend."""
        original = _stats_manager.get_backend()
        custom_backend = MemoryBackend(max_size=10)
        try:
            _stats_manager.set_backend(custom_backend)
            assert _stats_manager.get_backend() is custom_backend
        finally:
            # Restore the original backend so later tests are not affected
            _stats_manager.set_backend(original)
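
    # Added check: aggregated stats should also sum saved latency across
    # namespaces, not just hit/miss counts. Assumes get_stats(None) folds
    # per-namespace totals with the same semantics as CacheStats.__iadd__.
    def test_get_stats_all_namespaces_sums_latency(self):
        """Test aggregated stats sum total_saved_ms across namespaces."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)

        stats = _stats_manager.get_stats(None)
        assert stats.total_saved_ms == 150.0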


class TestPublicStatsAPI:
    """Tests for public stats API functions."""

    def test_get_stats(self):
        """Test get_stats returns a dictionary."""
        _stats_manager.record_hit("test_ns", latency_saved_ms=100.0)

        stats = get_stats("test_ns")
        assert isinstance(stats, dict)
        assert "hits" in stats
        assert "misses" in stats
        assert "hit_rate" in stats

    def test_clear_cache_all(self):
        """Test clear_cache clears all entries."""
        backend = _stats_manager.get_backend()
        backend.clear()  # start from a known-empty backend

        # Add some entries
        entry = CacheEntry(prompt="test", response="response", created_at=time.time())
        backend.set("key1", entry)
        backend.set("key2", entry)

        count = clear_cache()
        assert count == 2

    def test_clear_cache_namespace(self):
        """Test clear_cache clears a specific namespace."""
        backend = _stats_manager.get_backend()
        backend.clear()

        # Add entries in different namespaces
        entry1 = CacheEntry(prompt="test1", response="r1", namespace="ns1", created_at=time.time())
        entry2 = CacheEntry(prompt="test2", response="r2", namespace="ns2", created_at=time.time())

        backend.set("key1", entry1)
        backend.set("key2", entry2)

        # Only the ns1 entry should be cleared
        count = clear_cache(namespace="ns1")
        assert count == 1

    def test_invalidate_pattern(self):
        """Test invalidating entries by pattern."""
        backend = _stats_manager.get_backend()
        backend.clear()

        # Add entries with different prompts
        entry1 = CacheEntry(prompt="Python programming", response="r1", created_at=time.time())
        entry2 = CacheEntry(prompt="Rust programming", response="r2", created_at=time.time())
        entry3 = CacheEntry(prompt="JavaScript", response="r3", created_at=time.time())

        backend.set("key1", entry1)
        backend.set("key2", entry2)
        backend.set("key3", entry3)

        # Only the "Python programming" prompt matches
        count = invalidate("Python")
        assert count == 1

    def test_invalidate_case_insensitive(self):
        """Test invalidate is case insensitive."""
        backend = _stats_manager.get_backend()
        backend.clear()

        entry = CacheEntry(prompt="PYTHON programming", response="r1", created_at=time.time())
        backend.set("key1", entry)

        count = invalidate("python")
        assert count == 1

    def test_invalidate_no_matches(self):
        """Test invalidate with no matches."""
        backend = _stats_manager.get_backend()
        backend.clear()

        entry = CacheEntry(prompt="Rust programming", response="r1", created_at=time.time())
        backend.set("key1", entry)

        count = invalidate("Python")
        assert count == 0

    def test_warm_cache(self):
        """Test warming the cache with prompts."""
        prompts = ["prompt1", "prompt2"]

        def mock_llm(prompt: str) -> str:
            return f"Response to: {prompt}"

        count = warm_cache(prompts, mock_llm, namespace="warm_test")
        assert count == len(prompts)

    def test_warm_cache_with_failures(self):
        """Test warm_cache handles LLM failures gracefully."""

        def failing_llm(prompt: str) -> str:
            if "fail" in prompt:
                raise ValueError("LLM error")
            return f"Response to: {prompt}"

        prompts = ["prompt1", "fail_prompt", "prompt3"]
        count = warm_cache(prompts, failing_llm, namespace="warm_fail_test")
        # warm_cache reports every prompt it processed, even ones that failed
        assert count == len(prompts)
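
    # Added check: warming with an empty prompt list should be a harmless
    # no-op. Assumes warm_cache returns the number of prompts processed, as
    # the tests above imply.
    def test_warm_cache_empty(self):
        """Test warming the cache with an empty prompt list."""
        count = warm_cache([], lambda prompt: "unused", namespace="warm_empty_test")
        assert count == 0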


class TestExportCache:
    """Tests for the export_cache function."""

    def test_export_all_entries(self):
        """Test exporting all cache entries."""
        backend = _stats_manager.get_backend()
        backend.clear()

        # Add a test entry
        entry1 = CacheEntry(
            prompt="test prompt 1",
            response="response 1",
            namespace="test_ns",
            created_at=time.time(),
            hit_count=5,
            ttl=3600,
            input_tokens=100,
            output_tokens=50,
        )
        backend.set("key1", entry1)

        entries = export_cache()
        assert len(entries) == 1

        assert "key" in entries[0]
        assert "prompt" in entries[0]
        assert "response" in entries[0]
        assert "namespace" in entries[0]
        assert "hit_count" in entries[0]

    def test_export_namespace_filtered(self):
        """Test exporting entries filtered by namespace."""
        backend = _stats_manager.get_backend()
        backend.clear()

        # Add entries in different namespaces
        entry1 = CacheEntry(
            prompt="test1", response="r1", namespace="ns1", created_at=time.time()
        )
        entry2 = CacheEntry(
            prompt="test2", response="r2", namespace="ns2", created_at=time.time()
        )

        backend.set("key1", entry1)
        backend.set("key2", entry2)

        entries = export_cache(namespace="ns1")
        # Should only return entries from ns1
        assert entries
        assert all(e["namespace"] == "ns1" for e in entries)

    def test_export_to_file(self, tmp_path):
        """Test exporting the cache to a JSON file."""
        filepath = tmp_path / "export.json"

        backend = _stats_manager.get_backend()
        backend.clear()

        entry = CacheEntry(
            prompt="test",
            response="response",
            namespace="test",
            created_at=time.time(),
            hit_count=3,
        )
        backend.set("key1", entry)

        export_cache(filepath=str(filepath))

        # Verify the file was created and contains valid JSON
        assert filepath.exists()
        with open(filepath) as f:
            data = json.load(f)
        assert isinstance(data, list)

    def test_export_truncates_large_responses(self):
        """Test that large responses are truncated in export."""
        backend = _stats_manager.get_backend()
        backend.clear()

        # Create an entry with a very large response
        large_response = "x" * 2000
        entry = CacheEntry(
            prompt="test",
            response=large_response,
            created_at=time.time(),
        )
        backend.set("key1", entry)

        entries = export_cache()
        assert len(entries) == 1
        # Responses should be truncated to at most 1000 chars
        assert len(entries[0]["response"]) <= 1000
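
    # Added check: exporting an empty backend should yield an empty list,
    # matching the list-of-dicts shape verified above. Assumes export_cache
    # never returns None.
    def test_export_empty_cache(self):
        """Test exporting when the cache has no entries."""
        backend = _stats_manager.get_backend()
        backend.clear()

        assert export_cache() == []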


class TestStatsIntegration:
    """Integration tests for stats with actual cache operations."""

    def test_stats_tracking_with_cache_decorator(self):
        """Test that stats are tracked during cache operations."""
        from semantic_llm_cache import cache

        backend = MemoryBackend()
        _stats_manager.clear_stats("integration_test")

        @cache(backend=backend, namespace="integration_test")
        def cached_func(prompt: str) -> str:
            return f"Response to: {prompt}"

        # Generate activity
        cached_func("prompt1")
        cached_func("prompt1")  # Hit
        cached_func("prompt2")

        stats = get_stats("integration_test")
        assert stats["total_requests"] >= 2

    def test_cache_invalidate_integration(self):
        """Test invalidate removes entries from the backend."""
        backend = _stats_manager.get_backend()
        backend.clear()

        entry = CacheEntry(
            prompt="Python is great",
            response="Yes, it is!",
            created_at=time.time(),
        )
        backend.set("key1", entry)

        # Verify the entry exists
        assert backend.get("key1") is not None

        # Invalidate
        invalidate("Python")

        # The entry should be gone
        assert backend.get("key1") is None

    def test_export_includes_metadata(self):
        """Test export includes all metadata fields."""
        backend = _stats_manager.get_backend()
        backend.clear()

        entry = CacheEntry(
            prompt="test prompt",
            response="test response",
            namespace="export_test",
            created_at=time.time(),
            ttl=7200,
            hit_count=10,
            input_tokens=500,
            output_tokens=250,
            embedding=[0.1, 0.2, 0.3],
        )
        backend.set("key1", entry)

        entries = export_cache(namespace="export_test")
        assert len(entries) == 1
        e = entries[0]
        assert "created_at" in e
        assert "ttl" in e
        assert "hit_count" in e
        assert "input_tokens" in e
        assert "output_tokens" in e
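
    # Added check: a miss recorded through the manager should surface in the
    # public get_stats wrapper. Assumes get_stats("ns") is a thin wrapper over
    # _stats_manager.get_stats("ns").to_dict(), as TestPublicStatsAPI implies.
    def test_public_get_stats_reflects_misses(self):
        """Test that public get_stats reflects recorded misses."""
        _stats_manager.record_miss("integration_misses")

        stats = get_stats("integration_misses")
        assert stats["misses"] == 1
        assert stats["hit_rate"] == 0.0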