Add files via upload

initial commit
This commit is contained in:
Alpha Nerd 2026-03-06 15:54:47 +01:00 committed by GitHub
parent 8d3d5ff628
commit b33bb415dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 4840 additions and 0 deletions

1
tests/__init__.py Normal file
View file

@ -0,0 +1 @@
"""Tests for prompt-cache."""

50
tests/conftest.py Normal file
View file

@ -0,0 +1,50 @@
"""Pytest configuration and fixtures for prompt-cache tests."""
import time
import pytest
from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.config import CacheConfig, CacheEntry
@pytest.fixture
def backend():
"""Provide a fresh memory backend for each test."""
return MemoryBackend()
@pytest.fixture
def cache_config():
"""Provide default cache configuration."""
return CacheConfig()
@pytest.fixture
def sample_entry():
"""Provide a sample cache entry."""
return CacheEntry(
prompt="What is Python?",
response="Python is a programming language.",
embedding=[0.1, 0.2, 0.3],
created_at=time.time(), # Use current time
ttl=3600,
namespace="default",
hit_count=0,
)
@pytest.fixture
def mock_llm_func():
"""Provide a mock LLM function."""
responses = {
"What is Python?": "Python is a programming language.",
"What's Python?": "Python is a programming language.",
"Explain Python": "Python is a high-level programming language.",
"What is Rust?": "Rust is a systems programming language.",
}
def _func(prompt: str) -> str:
return responses.get(prompt, f"Response to: {prompt}")
return _func

687
tests/test_backends.py Normal file
View file

@ -0,0 +1,687 @@
"""Tests for storage backends."""
import json
import time
from unittest.mock import MagicMock, patch
import pytest
from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.backends.sqlite import SQLiteBackend # noqa: F401
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.exceptions import CacheBackendError
class TestBaseBackend:
"""Tests for BaseBackend abstract class."""
def test_cosine_similarity(self):
"""Test cosine similarity helper method."""
backend = MemoryBackend()
entry1 = CacheEntry(
prompt="test",
response="response",
embedding=[1.0, 0.0, 0.0],
)
entry2 = CacheEntry(
prompt="test",
response="response",
embedding=[1.0, 0.0, 0.0],
)
entry3 = CacheEntry(
prompt="test",
response="response",
embedding=[0.0, 1.0, 0.0],
)
# Test _find_best_match
candidates = [("key1", entry1), ("key2", entry2), ("key3", entry3)]
# Query matching entry1
result = backend._find_best_match(candidates, [1.0, 0.0, 0.0], threshold=0.9)
assert result is not None
key, entry, sim = result
assert sim == pytest.approx(1.0)
class TestMemoryBackend:
"""Tests for MemoryBackend."""
def test_set_and_get(self, backend, sample_entry):
"""Test basic set and get operations."""
backend.set("key1", sample_entry)
retrieved = backend.get("key1")
assert retrieved is not None
assert retrieved.prompt == sample_entry.prompt
assert retrieved.response == sample_entry.response
def test_get_nonexistent(self, backend):
"""Test getting non-existent key returns None."""
result = backend.get("nonexistent")
assert result is None
def test_delete(self, backend, sample_entry):
"""Test delete operation."""
backend.set("key1", sample_entry)
assert backend.get("key1") is not None
assert backend.delete("key1") is True
assert backend.get("key1") is None
def test_delete_nonexistent(self, backend):
"""Test deleting non-existent key returns False."""
assert backend.delete("nonexistent") is False
def test_clear(self, backend, sample_entry):
"""Test clear operation."""
backend.set("key1", sample_entry)
backend.set("key2", sample_entry)
backend.clear()
assert backend.get("key1") is None
assert backend.get("key2") is None
def test_iterate_all(self, backend):
"""Test iterating over all entries."""
entry1 = CacheEntry(prompt="p1", response="r1", created_at=time.time())
entry2 = CacheEntry(prompt="p2", response="r2", created_at=time.time())
backend.set("key1", entry1)
backend.set("key2", entry2)
results = backend.iterate()
assert len(results) == 2
def test_iterate_with_namespace(self, backend):
"""Test iterating with namespace filter."""
entry1 = CacheEntry(
prompt="p1", response="r1", namespace="ns1", created_at=time.time()
)
entry2 = CacheEntry(
prompt="p2", response="r2", namespace="ns2", created_at=time.time()
)
backend.set("key1", entry1)
backend.set("key2", entry2)
results = backend.iterate(namespace="ns1")
assert len(results) == 1
assert results[0][1].namespace == "ns1"
def test_find_similar(self, backend):
"""Test finding semantically similar entries."""
entry1 = CacheEntry(
prompt="What is Python?",
response="r1",
embedding=[1.0, 0.0, 0.0],
created_at=time.time(),
)
entry2 = CacheEntry(
prompt="What is Rust?",
response="r2",
embedding=[0.0, 1.0, 0.0],
created_at=time.time(),
)
backend.set("key1", entry1)
backend.set("key2", entry2)
# Find similar to entry1
result = backend.find_similar([1.0, 0.0, 0.0], threshold=0.9)
assert result is not None
key, entry, sim = result
assert key == "key1"
def test_find_similar_no_match(self, backend):
"""Test find_similar returns None when below threshold."""
entry = CacheEntry(
prompt="test",
response="response",
embedding=[1.0, 0.0, 0.0],
created_at=time.time(),
)
backend.set("key1", entry)
# Query with orthogonal vector
result = backend.find_similar([0.0, 1.0, 0.0], threshold=0.9)
assert result is None
def test_get_stats(self, backend):
"""Test get_stats returns correct info."""
entry = CacheEntry(
prompt="test",
response="response",
created_at=time.time(),
)
backend.set("key1", entry)
stats = backend.get_stats()
assert stats["size"] == 1
assert "hits" in stats
assert "misses" in stats
def test_hit_count_increments(self, backend, sample_entry):
"""Test hit count increments on cache hit."""
backend.set("key1", sample_entry)
backend.get("key1") # First hit
backend.get("key1") # Second hit
entry = backend.get("key1")
assert entry.hit_count >= 1
def test_lru_eviction(self):
"""Test LRU eviction when max_size is reached."""
backend = MemoryBackend(max_size=2)
entry1 = CacheEntry(prompt="p1", response="r1", created_at=time.time())
entry2 = CacheEntry(prompt="p2", response="r2", created_at=time.time())
entry3 = CacheEntry(prompt="p3", response="r3", created_at=time.time())
backend.set("key1", entry1)
backend.set("key2", entry2)
assert backend.get_stats()["size"] == 2
backend.set("key3", entry3)
# Should evict oldest (key1)
assert backend.get_stats()["size"] == 2
def test_expired_entry_not_returned(self, backend):
"""Test expired entries are not returned."""
entry = CacheEntry(
prompt="test",
response="response",
ttl=1,
created_at=time.time() - 2, # 2 seconds ago with 1s TTL
)
backend.set("key1", entry)
result = backend.get("key1")
assert result is None
class TestSQLiteBackend:
"""Tests for SQLiteBackend."""
@pytest.fixture
def sqlite_backend(self, tmp_path):
"""Create SQLite backend with temp database."""
from semantic_llm_cache.backends.sqlite import SQLiteBackend
db_path = tmp_path / "test_cache.db"
return SQLiteBackend(db_path)
def test_set_and_get(self, sqlite_backend, sample_entry):
"""Test basic set and get operations."""
sqlite_backend.set("key1", sample_entry)
retrieved = sqlite_backend.get("key1")
assert retrieved is not None
assert retrieved.prompt == sample_entry.prompt
assert retrieved.response == sample_entry.response
def test_persistence(self, sqlite_backend, sample_entry, tmp_path):
"""Test entries persist across backend instances."""
db_path = tmp_path / "test_persist.db"
# Create first instance
backend1 = SQLiteBackend(db_path)
backend1.set("key1", sample_entry)
# Create second instance (simulates restart)
backend2 = SQLiteBackend(db_path)
retrieved = backend2.get("key1")
assert retrieved is not None
assert retrieved.prompt == sample_entry.prompt
def test_get_stats(self, sqlite_backend):
"""Test get_stats returns correct info."""
entry = CacheEntry(prompt="test", response="response", created_at=time.time())
sqlite_backend.set("key1", entry)
stats = sqlite_backend.get_stats()
assert stats["size"] == 1
assert "db_path" in stats
def test_clear(self, sqlite_backend, sample_entry):
"""Test clear operation."""
sqlite_backend.set("key1", sample_entry)
sqlite_backend.clear()
assert sqlite_backend.get("key1") is None
def test_close_and_reopen(self, sqlite_backend, sample_entry):
"""Test closing and reopening connection."""
sqlite_backend.set("key1", sample_entry)
sqlite_backend.close()
# Should be able to use after close (reopens connection)
retrieved = sqlite_backend.get("key1")
assert retrieved is not None
class TestRedisBackend:
"""Tests for RedisBackend."""
@pytest.fixture
def mock_redis(self):
"""Create mock Redis client."""
with patch("semantic_llm_cache.backends.redis.redis_lib") as mock:
mock_client = MagicMock()
mock.from_url.return_value = mock_client
mock_client.ping.return_value = True
mock_client.get.return_value = None
mock_client.keys.return_value = []
mock_client.delete.return_value = 1
yield mock_client
@pytest.fixture
def redis_backend(self, mock_redis):
"""Create Redis backend with mocked client."""
from semantic_llm_cache.backends.redis import RedisBackend
backend = RedisBackend(url="redis://localhost:6379/0")
backend._redis = mock_redis
return backend
def test_set_and_get(self, redis_backend, mock_redis, sample_entry):
"""Test basic set and get operations."""
# Mock get to return stored data
mock_redis.get.return_value = json.dumps({
"prompt": sample_entry.prompt,
"response": sample_entry.response,
"embedding": sample_entry.embedding,
"created_at": sample_entry.created_at,
"ttl": sample_entry.ttl,
"namespace": sample_entry.namespace,
"hit_count": 0,
"input_tokens": sample_entry.input_tokens,
"output_tokens": sample_entry.output_tokens,
}).encode()
redis_backend.set("key1", sample_entry)
retrieved = redis_backend.get("key1")
assert retrieved is not None
assert retrieved.prompt == sample_entry.prompt
# set is called twice: once for initial set, once to update hit_count
assert mock_redis.set.call_count == 2
def test_get_nonexistent(self, redis_backend, mock_redis):
"""Test getting non-existent key returns None."""
mock_redis.get.return_value = None
result = redis_backend.get("nonexistent")
assert result is None
def test_delete(self, redis_backend, mock_redis, sample_entry):
"""Test delete operation."""
mock_redis.delete.return_value = 1
result = redis_backend.delete("key1")
assert result is True
mock_redis.delete.assert_called_once()
def test_delete_nonexistent(self, redis_backend, mock_redis):
"""Test deleting non-existent key returns False."""
mock_redis.delete.return_value = 0
result = redis_backend.delete("nonexistent")
assert result is False
def test_clear(self, redis_backend, mock_redis):
"""Test clear operation."""
mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
redis_backend.clear()
mock_redis.delete.assert_called_once()
def test_clear_empty(self, redis_backend, mock_redis):
"""Test clear with no entries."""
mock_redis.keys.return_value = []
redis_backend.clear()
mock_redis.delete.assert_not_called()
def test_iterate_all(self, redis_backend, mock_redis, sample_entry):
"""Test iterating over all entries."""
entry_dict = {
"prompt": sample_entry.prompt,
"response": sample_entry.response,
"embedding": sample_entry.embedding,
"created_at": sample_entry.created_at,
"ttl": sample_entry.ttl,
"namespace": sample_entry.namespace,
"hit_count": 0,
"input_tokens": sample_entry.input_tokens,
"output_tokens": sample_entry.output_tokens,
}
mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
mock_redis.get.return_value = json.dumps(entry_dict).encode()
results = redis_backend.iterate()
assert len(results) == 2
def test_iterate_with_namespace(self, redis_backend, mock_redis, sample_entry):
"""Test iterating with namespace filter."""
entry1_dict = {
"prompt": "p1",
"response": "r1",
"embedding": None,
"created_at": time.time(),
"ttl": None,
"namespace": "ns1",
"hit_count": 0,
"input_tokens": 0,
"output_tokens": 0,
}
entry2_dict = entry1_dict.copy()
entry2_dict["namespace"] = "ns2"
call_count = [0]
def mock_get(key):
call_count[0] += 1
if call_count[0] == 1:
return json.dumps(entry1_dict).encode()
return json.dumps(entry2_dict).encode()
mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
mock_redis.get.side_effect = mock_get
results = redis_backend.iterate(namespace="ns1")
assert len(results) == 1
assert results[0][1].namespace == "ns1"
def test_find_similar(self, redis_backend, mock_redis):
"""Test finding semantically similar entries."""
entry_dict = {
"prompt": "What is Python?",
"response": "r1",
"embedding": [1.0, 0.0, 0.0],
"created_at": time.time(),
"ttl": None,
"namespace": "default",
"hit_count": 0,
"input_tokens": 0,
"output_tokens": 0,
}
mock_redis.keys.return_value = [b"semantic_llm_cache:key1"]
mock_redis.get.return_value = json.dumps(entry_dict).encode()
result = redis_backend.find_similar([1.0, 0.0, 0.0], threshold=0.9)
assert result is not None
key, entry, sim = result
assert key == "key1"
def test_find_similar_no_match(self, redis_backend, mock_redis):
"""Test find_similar returns None when below threshold."""
entry_dict = {
"prompt": "test",
"response": "response",
"embedding": [1.0, 0.0, 0.0],
"created_at": time.time(),
"ttl": None,
"namespace": "default",
"hit_count": 0,
"input_tokens": 0,
"output_tokens": 0,
}
mock_redis.keys.return_value = [b"semantic_llm_cache:key1"]
mock_redis.get.return_value = json.dumps(entry_dict).encode()
# Query with orthogonal vector
result = redis_backend.find_similar([0.0, 1.0, 0.0], threshold=0.9)
assert result is None
def test_get_stats(self, redis_backend, mock_redis):
"""Test get_stats returns correct info."""
mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
stats = redis_backend.get_stats()
assert "prefix" in stats
assert stats["size"] == 2
assert stats["prefix"] == "semantic_llm_cache:"
def test_get_stats_error_handling(self, redis_backend, mock_redis):
"""Test get_stats handles Redis errors gracefully."""
mock_redis.keys.side_effect = Exception("Connection lost")
stats = redis_backend.get_stats()
assert "error" in stats
assert stats["size"] == 0
def test_make_key(self, redis_backend):
"""Test key prefixing."""
result = redis_backend._make_key("test_key")
assert result == "semantic_llm_cache:test_key"
def test_entry_to_dict(self, redis_backend, sample_entry):
"""Test converting entry to dictionary."""
result = redis_backend._entry_to_dict(sample_entry)
assert result["prompt"] == sample_entry.prompt
assert result["response"] == sample_entry.response
assert result["embedding"] == sample_entry.embedding
def test_dict_to_entry(self, redis_backend):
"""Test converting dictionary to entry."""
data = {
"prompt": "test",
"response": "response",
"embedding": [1.0, 0.0],
"created_at": time.time(),
"ttl": 100,
"namespace": "test_ns",
"hit_count": 5,
"input_tokens": 100,
"output_tokens": 50,
}
entry = redis_backend._dict_to_entry(data)
assert entry.prompt == "test"
assert entry.namespace == "test_ns"
assert entry.hit_count == 5
def test_dict_to_entry_defaults(self, redis_backend):
"""Test dict_to_entry uses defaults for missing fields."""
data = {
"prompt": "test",
"response": "response",
"created_at": time.time(),
}
entry = redis_backend._dict_to_entry(data)
assert entry.embedding is None
assert entry.ttl is None
assert entry.namespace == "default"
assert entry.hit_count == 0
assert entry.input_tokens == 0
assert entry.output_tokens == 0
def test_connection_failure(self):
"""Test connection failure raises CacheBackendError."""
from semantic_llm_cache.backends.redis import RedisBackend
from semantic_llm_cache.exceptions import CacheBackendError
# Need to patch both the import and the from_url call
with patch("semantic_llm_cache.backends.redis.redis_lib") as mock_redis:
mock_client = MagicMock()
mock_client.ping.side_effect = Exception("Connection refused")
mock_redis.from_url.return_value = mock_client
with pytest.raises(CacheBackendError, match="Failed to connect"):
RedisBackend(url="redis://localhost:9999/0")
def test_set_with_ttl(self, redis_backend, mock_redis, sample_entry):
"""Test setting entry with TTL."""
sample_entry.ttl = 3600
redis_backend.set("key1", sample_entry)
call_args = mock_redis.set.call_args
assert call_args[1]["ex"] == 3600
def test_get_expired_entry(self, redis_backend, mock_redis):
"""Test expired entry is not returned."""
expired_dict = {
"prompt": "test",
"response": "response",
"embedding": None,
"created_at": time.time() - 1000,
"ttl": 100, # 100 seconds TTL, created 1000 seconds ago
"namespace": "default",
"hit_count": 0,
"input_tokens": 0,
"output_tokens": 0,
}
mock_redis.get.return_value = json.dumps(expired_dict).encode()
mock_redis.delete.return_value = 1
result = redis_backend.get("expired_key")
assert result is None
mock_redis.delete.assert_called_once()
def test_close(self, redis_backend, mock_redis):
"""Test closing Redis connection."""
redis_backend.close()
mock_redis.close.assert_called_once()
def test_close_error_handling(self, redis_backend, mock_redis):
"""Test close handles errors gracefully."""
mock_redis.close.side_effect = Exception("Close error")
# Should not raise
redis_backend.close()
def test_iterate_with_expired_entries(self, redis_backend, mock_redis):
"""Test iterate filters out expired entries."""
expired_dict = {
"prompt": "expired",
"response": "response",
"embedding": None,
"created_at": time.time() - 1000,
"ttl": 100,
"namespace": "default",
"hit_count": 0,
"input_tokens": 0,
"output_tokens": 0,
}
valid_dict = {
"prompt": "valid",
"response": "response",
"embedding": None,
"created_at": time.time(),
"ttl": None,
"namespace": "default",
"hit_count": 0,
"input_tokens": 0,
"output_tokens": 0,
}
call_count = [0]
def mock_get(key):
call_count[0] += 1
if call_count[0] == 1:
return json.dumps(expired_dict).encode()
return json.dumps(valid_dict).encode()
mock_redis.keys.return_value = [b"semantic_llm_cache:expired", b"semantic_llm_cache:valid"]
mock_redis.get.side_effect = mock_get
results = redis_backend.iterate()
# Only valid entry should be returned
assert len(results) == 1
assert results[0][1].prompt == "valid"
def test_set_error_handling(self, redis_backend, mock_redis, sample_entry):
"""Test set handles Redis errors."""
from semantic_llm_cache.exceptions import CacheBackendError
mock_redis.set.side_effect = Exception("Redis error")
with pytest.raises(CacheBackendError, match="Failed to set"):
redis_backend.set("key1", sample_entry)
def test_delete_error_handling(self, redis_backend, mock_redis):
"""Test delete handles Redis errors."""
from semantic_llm_cache.exceptions import CacheBackendError
mock_redis.delete.side_effect = Exception("Redis error")
with pytest.raises(CacheBackendError, match="Failed to delete"):
redis_backend.delete("key1")
def test_iterate_error_handling(self, redis_backend, mock_redis):
"""Test iterate handles Redis errors."""
from semantic_llm_cache.exceptions import CacheBackendError
mock_redis.keys.side_effect = Exception("Redis error")
with pytest.raises(CacheBackendError, match="Failed to iterate"):
redis_backend.iterate()
def test_get_json_error(self, redis_backend, mock_redis):
"""Test get handles invalid JSON."""
import json
mock_redis.get.return_value = b"invalid json"
# The JSON decode error should be wrapped in CacheBackendError
with pytest.raises((CacheBackendError, json.JSONDecodeError)):
redis_backend.get("key1")
def test_import_error_without_package(self):
"""Test ImportError when redis package not installed."""
# This test validates the import guard in redis.py
from semantic_llm_cache.backends import redis as redis_module
# Check that RedisBackend is defined
assert hasattr(redis_module, "RedisBackend")
class TestCacheEntry:
"""Tests for CacheEntry dataclass."""
def test_is_expired_with_none_ttl(self):
"""Test entry with None TTL never expires."""
entry = CacheEntry(
prompt="test",
response="response",
ttl=None,
created_at=time.time() - 10000,
)
assert not entry.is_expired(time.time())
def test_is_expired_with_ttl(self):
"""Test entry with TTL expires correctly."""
entry = CacheEntry(
prompt="test",
response="response",
ttl=10,
created_at=time.time() - 15,
)
assert entry.is_expired(time.time())
def test_is_expired_not_yet(self):
"""Test entry not yet expired."""
entry = CacheEntry(
prompt="test",
response="response",
ttl=10,
created_at=time.time() - 5,
)
assert not entry.is_expired(time.time())
def test_estimate_cost(self):
"""Test cost estimation."""
entry = CacheEntry(
prompt="test",
response="response",
input_tokens=1000,
output_tokens=500,
)
cost = entry.estimate_cost(0.001, 0.002)
# 1000/1000 * 0.001 + 500/1000 * 0.002 = 0.001 + 0.001 = 0.002
assert cost == pytest.approx(0.002)

606
tests/test_core.py Normal file
View file

@ -0,0 +1,606 @@
"""Tests for core cache decorator and API."""
import time
from unittest.mock import MagicMock
import pytest
from semantic_llm_cache import CacheContext, CachedLLM, cache, get_default_backend, set_default_backend
from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.config import CacheConfig, CacheEntry
from semantic_llm_cache.exceptions import PromptCacheError
class TestCacheDecorator:
"""Tests for @cache decorator."""
def test_exact_match_cache_hit(self, mock_llm_func):
"""Test exact match caching returns cached result."""
call_count = {"count": 0}
@cache()
def cached_func(prompt: str) -> str:
call_count["count"] += 1
return mock_llm_func(prompt)
# First call - cache miss
result1 = cached_func("What is Python?")
assert result1 == "Python is a programming language."
assert call_count["count"] == 1
# Second call - cache hit
result2 = cached_func("What is Python?")
assert result2 == "Python is a programming language."
assert call_count["count"] == 1 # No additional call
def test_cache_miss_different_prompt(self, mock_llm_func):
"""Test different prompts result in cache misses."""
@cache()
def cached_func(prompt: str) -> str:
return mock_llm_func(prompt)
result1 = cached_func("What is Python?")
result2 = cached_func("What is Rust?")
assert result1 == "Python is a programming language."
assert result2 == "Rust is a systems programming language."
def test_cache_disabled(self, mock_llm_func):
"""Test caching can be disabled."""
call_count = {"count": 0}
@cache(enabled=False)
def cached_func(prompt: str) -> str:
call_count["count"] += 1
return mock_llm_func(prompt)
cached_func("What is Python?")
cached_func("What is Python?")
assert call_count["count"] == 2 # Both calls hit the function
def test_custom_namespace(self, mock_llm_func):
"""Test custom namespace isolates cache."""
@cache(namespace="test")
def cached_func(prompt: str) -> str:
return mock_llm_func(prompt)
result = cached_func("What is Python?")
assert result == "Python is a programming language."
def test_ttl_expiration(self, mock_llm_func):
"""Test TTL expiration works."""
@cache(ttl=1) # 1 second TTL
def cached_func(prompt: str) -> str:
return mock_llm_func(prompt)
# First call
result1 = cached_func("What is Python?")
assert result1 == "Python is a programming language."
# Immediate second call - should hit cache
result2 = cached_func("What is Python?")
assert result2 == "Python is a programming language."
# Wait for expiration
time.sleep(1.1)
# Third call - should miss due to TTL
# Note: This test may be flaky in slow CI environments
result3 = cached_func("What is Python?")
assert result3 == "Python is a programming language."
def test_cache_with_exception(self):
"""Test cache handles exceptions properly."""
@cache()
def failing_func(prompt: str) -> str:
raise ValueError("LLM API error")
with pytest.raises(PromptCacheError):
failing_func("test prompt")
def test_semantic_similarity_match(self, mock_llm_func):
"""Test semantic similarity matching."""
call_count = {"count": 0}
@cache(similarity=0.85)
def cached_func(prompt: str) -> str:
call_count["count"] += 1
return mock_llm_func(prompt)
# First call
cached_func("What is Python?")
assert call_count["count"] == 1
# Similar prompts may hit cache depending on embedding
# Note: With dummy embeddings, exact string matching determines similarity
cached_func("What is Python?") # Exact match
assert call_count["count"] == 1
class TestCacheContext:
"""Tests for CacheContext manager."""
def test_context_manager(self):
"""Test CacheContext works as context manager."""
with CacheContext(similarity=0.9, ttl=1800) as ctx:
assert ctx.config.similarity_threshold == 0.9
assert ctx.config.ttl == 1800
def test_context_stats(self):
"""Test context tracks stats."""
with CacheContext() as ctx:
stats = ctx.stats
assert "hits" in stats
assert "misses" in stats
class TestCachedLLM:
"""Tests for CachedLLM wrapper class."""
def test_init(self):
"""Test CachedLLM initialization."""
llm = CachedLLM(provider="openai", model="gpt-4")
assert llm._provider == "openai"
assert llm._model == "gpt-4"
def test_chat_with_llm_func(self):
"""Test chat method with custom LLM function."""
llm = CachedLLM()
def mock_llm(prompt: str) -> str:
return f"Response to: {prompt}"
result = llm.chat("Hello", llm_func=mock_llm)
assert result == "Response to: Hello"
def test_chat_caches_responses(self, mock_llm_func):
"""Test chat caches responses."""
llm = CachedLLM()
call_count = {"count": 0}
def counting_llm(prompt: str) -> str:
call_count["count"] += 1
return mock_llm_func(prompt)
llm.chat("What is Python?", llm_func=counting_llm)
llm.chat("What is Python?", llm_func=counting_llm)
# Should cache (depends on embedding, may not with dummy)
assert call_count["count"] >= 1
class TestBackendManagement:
"""Tests for backend management functions."""
def test_get_default_backend(self):
"""Test get_default_backend returns a backend."""
backend = get_default_backend()
assert backend is not None
assert isinstance(backend, MemoryBackend)
def test_set_default_backend(self):
"""Test set_default_backend changes default."""
custom_backend = MemoryBackend(max_size=10)
set_default_backend(custom_backend)
backend = get_default_backend()
assert backend is custom_backend
class TestCacheEntry:
"""Tests for CacheEntry class."""
def test_entry_creation(self):
"""Test CacheEntry creation."""
entry = CacheEntry(
prompt="test",
response="response",
created_at=time.time(),
)
assert entry.prompt == "test"
assert entry.response == "response"
def test_is_expired_no_ttl(self):
"""Test entry without TTL never expires."""
entry = CacheEntry(
prompt="test",
response="response",
ttl=None,
created_at=time.time() - 1000,
)
assert not entry.is_expired(time.time())
def test_is_expired_with_ttl(self):
"""Test entry with TTL expires correctly."""
entry = CacheEntry(
prompt="test",
response="response",
ttl=1, # 1 second
created_at=time.time() - 2, # 2 seconds ago
)
assert entry.is_expired(time.time())
def test_estimate_cost(self):
"""Test cost estimation."""
entry = CacheEntry(
prompt="test",
response="response",
input_tokens=100,
output_tokens=50,
)
cost = entry.estimate_cost(0.001, 0.002)
# 100/1000 * 0.001 + 50/1000 * 0.002 = 0.0001 + 0.0001 = 0.0002
assert abs(cost - 0.0002) < 1e-6
class TestCacheConfig:
"""Tests for CacheConfig class."""
def test_default_config(self):
"""Test default configuration values."""
config = CacheConfig()
assert config.similarity_threshold == 1.0
assert config.ttl == 3600
assert config.namespace == "default"
assert config.enabled is True
def test_custom_config(self):
"""Test custom configuration values."""
config = CacheConfig(
similarity_threshold=0.9,
ttl=7200,
namespace="custom",
enabled=False,
)
assert config.similarity_threshold == 0.9
assert config.ttl == 7200
assert config.namespace == "custom"
assert config.enabled is False
def test_invalid_similarity(self):
"""Test invalid similarity raises error."""
with pytest.raises(ValueError, match="similarity_threshold"):
CacheConfig(similarity_threshold=1.5)
def test_invalid_ttl(self):
"""Test invalid TTL raises error."""
with pytest.raises(ValueError, match="ttl"):
CacheConfig(ttl=-1)
def test_invalid_max_size(self):
"""Test invalid max_size raises error."""
with pytest.raises(ValueError, match="max_cache_size"):
CacheConfig(max_cache_size=0)
class TestCacheDecoratorEdgeCases:
"""Tests for edge cases in cache decorator."""
def test_cache_with_kwargs_only(self, mock_llm_func):
"""Test caching when function is called with kwargs only."""
@cache()
def cached_func(prompt: str) -> str:
return mock_llm_func(prompt)
result = cached_func(prompt="What is Python?")
assert result == "Python is a programming language."
def test_cache_with_multiple_args(self):
"""Test caching with multiple arguments."""
call_count = {"count": 0}
@cache()
def cached_func(prompt: str, temperature: float = 0.7) -> str:
call_count["count"] += 1
return f"Response to: {prompt} at {temperature}"
cached_func("test", 0.5)
cached_func("test", 0.5)
# Same prompt hits cache even with different temperature
# (cache key is based on first arg)
cached_func("test", 0.9)
cached_func("different", 0.5)
# First call + different prompt = 2 calls
assert call_count["count"] == 2
def test_cache_with_custom_key_func(self):
"""Test custom key function."""
call_count = {"count": 0}
def custom_key(prompt: str, temperature: float = 0.7) -> str:
return f"{prompt}:{temperature}"
@cache(key_func=custom_key)
def cached_func(prompt: str, temperature: float = 0.7) -> str:
call_count["count"] += 1
return f"Response to: {prompt} at {temperature}"
cached_func("test", 0.7)
cached_func("test", 0.7)
assert call_count["count"] == 1
def test_semantic_match_threshold_edge(self, mock_llm_func):
"""Test semantic matching at threshold boundaries."""
call_count = {"count": 0}
@cache(similarity=0.5) # Lower threshold
def cached_func(prompt: str) -> str:
call_count["count"] += 1
return mock_llm_func(prompt)
# First call
cached_func("What is Python?")
# Similar query may hit with lower threshold
cached_func("What is Python?") # Exact match always hits
assert call_count["count"] == 1
def test_cache_with_none_response(self):
"""Test caching None responses."""
call_count = {"count": 0}
@cache()
def cached_func(prompt: str) -> None:
call_count["count"] += 1
return None
result1 = cached_func("test")
result2 = cached_func("test")
assert result1 is None
assert result2 is None
assert call_count["count"] == 1 # Should cache
def test_cache_with_empty_string_response(self):
"""Test caching empty string responses."""
call_count = {"count": 0}
@cache()
def cached_func(prompt: str) -> str:
call_count["count"] += 1
return ""
result1 = cached_func("test")
result2 = cached_func("test")
assert result1 == ""
assert result2 == ""
assert call_count["count"] == 1 # Should cache
def test_cache_with_dict_response(self):
"""Test caching dict responses."""
call_count = {"count": 0}
@cache()
def cached_func(prompt: str) -> dict:
call_count["count"] += 1
return {"key": "value", "number": 42}
result1 = cached_func("test")
cached_func("test") # Second call to verify caching
assert result1 == {"key": "value", "number": 42}
assert call_count["count"] == 1
def test_cache_with_list_response(self):
"""Test caching list responses."""
call_count = {"count": 0}
@cache()
def cached_func(prompt: str) -> list:
call_count["count"] += 1
return [1, 2, 3, 4, 5]
result1 = cached_func("test")
cached_func("test") # Second call to verify caching
assert result1 == [1, 2, 3, 4, 5]
assert call_count["count"] == 1
class TestCacheDecoratorErrorPaths:
"""Tests for error handling in cache decorator."""
def test_backend_set_raises_error(self):
"""Test that backend.set errors propagate."""
from semantic_llm_cache.exceptions import CacheBackendError
backend = MagicMock()
# Backend set wraps exceptions in CacheBackendError
backend.set.side_effect = CacheBackendError("Storage error")
backend.get.return_value = None # No cached value
@cache(backend=backend)
def cached_func(prompt: str) -> str:
return "response"
# The CacheBackendError from backend.set should propagate
with pytest.raises(CacheBackendError, match="Storage error"):
cached_func("test")
def test_backend_get_raises_error(self):
"""Test that backend.get errors propagate."""
from semantic_llm_cache.exceptions import CacheBackendError
backend = MagicMock()
# Backend get wraps exceptions in CacheBackendError
backend.get.side_effect = CacheBackendError("Get error")
@cache(backend=backend)
def cached_func(prompt: str) -> str:
return "response"
# The CacheBackendError from backend.get should propagate
with pytest.raises(CacheBackendError, match="Get error"):
cached_func("test")
def test_llm_error_still_wrapped(self):
"""Test that LLM errors are still wrapped in PromptCacheError."""
from semantic_llm_cache.exceptions import PromptCacheError
@cache()
def failing_func(prompt: str) -> str:
raise ValueError("LLM API error")
with pytest.raises(PromptCacheError):
failing_func("test")
class TestCacheContextAdvanced:
"""Advanced tests for CacheContext."""
def test_context_with_zero_similarity(self):
"""Test context with zero similarity (accept all)."""
with CacheContext(similarity=0.0) as ctx:
assert ctx.config.similarity_threshold == 0.0
def test_context_with_infinite_ttl(self):
"""Test context with infinite TTL (None)."""
with CacheContext(ttl=None) as ctx:
assert ctx.config.ttl is None
def test_context_disabled(self):
"""Test disabled context."""
with CacheContext(enabled=False) as ctx:
assert ctx.config.enabled is False
class TestCachedLLMAdvanced:
"""Advanced tests for CachedLLM."""
def test_chat_with_kwargs(self):
"""Test chat with additional kwargs passed to LLM."""
llm = CachedLLM()
def mock_llm(prompt: str, temperature: float = 0.7, max_tokens: int = 100) -> str:
return f"Response to: {prompt} (temp={temperature}, tokens={max_tokens})"
result = llm.chat("Hello", llm_func=mock_llm, temperature=0.5, max_tokens=200)
assert "temp=0.5" in result
assert "tokens=200" in result
def test_cached_llm_with_custom_backend(self):
"""Test CachedLLM with custom backend."""
custom_backend = MemoryBackend(max_size=5)
llm = CachedLLM(backend=custom_backend)
assert llm._backend is custom_backend
def test_cached_llm_different_namespaces(self):
"""Test CachedLLM with different namespaces."""
llm1 = CachedLLM(namespace="ns1")
llm2 = CachedLLM(namespace="ns2")
assert llm1._config.namespace == "ns1"
assert llm2._config.namespace == "ns2"
class TestBackendManagementAdvanced:
"""Advanced tests for backend management."""
def test_multiple_default_backend_changes(self):
"""Test changing default backend multiple times."""
backend1 = MemoryBackend(max_size=10)
backend2 = MemoryBackend(max_size=20)
backend3 = MemoryBackend(max_size=30)
set_default_backend(backend1)
assert get_default_backend() is backend1
set_default_backend(backend2)
assert get_default_backend() is backend2
set_default_backend(backend3)
assert get_default_backend() is backend3
def test_backend_persists_stats(self):
"""Test backend stats persist across get_default_backend calls."""
backend = get_default_backend()
# Create an entry
from semantic_llm_cache.config import CacheEntry
entry = CacheEntry(
prompt="test",
response="response",
created_at=time.time()
)
backend.set("key1", entry)
# Get stats
stats = backend.get_stats()
assert stats["size"] == 1
class TestCacheEntryEdgeCases:
"""Edge case tests for CacheEntry."""
def test_entry_with_zero_tokens(self):
"""Test entry with zero token counts."""
entry = CacheEntry(
prompt="test",
response="response",
input_tokens=0,
output_tokens=0,
)
cost = entry.estimate_cost(0.001, 0.002)
assert cost == 0.0
def test_entry_with_large_token_count(self):
"""Test entry with large token counts."""
entry = CacheEntry(
prompt="test",
response="response",
input_tokens=100000,
output_tokens=50000,
)
cost = entry.estimate_cost(0.001, 0.002)
assert cost > 0
def test_entry_with_negative_ttl(self):
"""Test entry creation handles negative TTL (becomes expired)."""
entry = CacheEntry(
prompt="test",
response="response",
ttl=-1,
created_at=time.time(),
)
# Negative TTL means immediately expired
assert entry.is_expired(time.time())
def test_entry_hit_count_initialization(self):
"""Test entry initializes with zero hit count."""
entry = CacheEntry(
prompt="test",
response="response",
created_at=time.time(),
)
assert entry.hit_count == 0
class TestCacheConfigEdgeCases:
"""Edge case tests for CacheConfig."""
def test_config_boundary_similarity(self):
"""Test similarity at valid boundaries."""
config1 = CacheConfig(similarity_threshold=0.0)
assert config1.similarity_threshold == 0.0
config2 = CacheConfig(similarity_threshold=1.0)
assert config2.similarity_threshold == 1.0
def test_config_zero_ttl(self):
"""Test zero TTL is rejected (validation requires positive)."""
# The validation in CacheConfig rejects ttl <= 0
with pytest.raises(ValueError, match="ttl"):
CacheConfig(ttl=0)
def test_config_very_large_ttl(self):
"""Test very large TTL."""
config = CacheConfig(ttl=86400 * 365) # 1 year
assert config.ttl == 86400 * 365
def test_config_with_special_namespace(self):
"""Test namespace with special characters."""
config = CacheConfig(namespace="test-ns_123.v1")
assert config.namespace == "test-ns_123.v1"

330
tests/test_integration.py Normal file
View file

@ -0,0 +1,330 @@
"""Integration tests for prompt-cache."""
import time
import pytest
from semantic_llm_cache import cache, clear_cache, get_stats, invalidate
from semantic_llm_cache.backends import MemoryBackend
class TestEndToEnd:
"""End-to-end integration tests."""
def test_full_cache_workflow(self):
"""Test complete cache workflow from hit to miss."""
backend = MemoryBackend()
call_count = {"count": 0}
@cache(backend=backend)
def llm_function(prompt: str) -> str:
call_count["count"] += 1
return f"Response to: {prompt}"
# First call - miss
result1 = llm_function("What is Python?")
assert result1 == "Response to: What is Python?"
assert call_count["count"] == 1
# Second call - hit
result2 = llm_function("What is Python?")
assert result2 == "Response to: What is Python?"
assert call_count["count"] == 1
# Different prompt - miss
result3 = llm_function("What is Rust?")
assert result3 == "Response to: What is Rust?"
assert call_count["count"] == 2
def test_stats_integration(self):
"""Test statistics tracking."""
backend = MemoryBackend()
@cache(backend=backend, namespace="test")
def llm_function(prompt: str) -> str:
return f"Response to: {prompt}"
# Generate some activity
llm_function("prompt 1")
llm_function("prompt 1") # Hit
llm_function("prompt 2")
stats = get_stats(namespace="test")
assert stats["total_requests"] >= 2
def test_clear_cache_integration(self):
"""Test clearing cache affects function behavior."""
backend = MemoryBackend()
@cache(backend=backend)
def llm_function(prompt: str) -> str:
return f"Response to: {prompt}"
llm_function("test prompt")
# Clear cache
cleared = clear_cache()
assert cleared >= 0
# Function should still work
result = llm_function("test prompt")
assert result == "Response to: test prompt"
def test_invalidate_integration(self):
"""Test invalidating cache entries."""
backend = MemoryBackend()
@cache(backend=backend)
def llm_function(prompt: str) -> str:
return f"Response to: {prompt}"
llm_function("Python programming")
llm_function("Rust programming")
# Invalidate Python entries
count = invalidate("Python")
assert count >= 0
def test_multiple_namespaces(self):
"""Test cache isolation across namespaces."""
backend = MemoryBackend()
@cache(backend=backend, namespace="app1")
def app1_llm(prompt: str) -> str:
return f"App1: {prompt}"
@cache(backend=backend, namespace="app2")
def app2_llm(prompt: str) -> str:
return f"App2: {prompt}"
result1 = app1_llm("test")
result2 = app2_llm("test")
assert result1 == "App1: test"
assert result2 == "App2: test"
def test_ttl_expiration_integration(self):
"""Test TTL expiration in real workflow."""
backend = MemoryBackend()
@cache(backend=backend, ttl=1) # 1 second TTL
def llm_function(prompt: str) -> str:
return f"Response to: {prompt}"
llm_function("test prompt")
# Immediate second call - hit
llm_function("test prompt")
# Wait for expiration
time.sleep(1.5)
# Should miss (cached entry expired)
llm_function("test prompt")
class TestComplexScenarios:
"""Tests for complex real-world scenarios."""
def test_high_volume_caching(self):
"""Test cache behavior with many entries."""
backend = MemoryBackend(max_size=100)
call_count = {"count": 0}
@cache(backend=backend)
def llm_function(prompt: str) -> str:
call_count["count"] += 1
return f"Response {call_count['count']}"
# Add many entries
for i in range(150):
llm_function(f"prompt {i}")
# Some entries should have been evicted
stats = backend.get_stats()
assert stats["size"] <= 100
def test_concurrent_like_access(self):
"""Test multiple calls to same cached entry."""
backend = MemoryBackend()
@cache(backend=backend)
def llm_function(prompt: str) -> str:
return f"Unique: {time.time()}"
# Multiple calls
results = [llm_function("test") for _ in range(5)]
# All should return same result (cached)
assert len(set(results)) == 1
def test_different_return_types(self):
"""Test caching different return types."""
backend = MemoryBackend()
@cache(backend=backend)
def return_dict(prompt: str) -> dict:
return {"key": "value"}
@cache(backend=backend)
def return_list(prompt: str) -> list:
return [1, 2, 3]
@cache(backend=backend)
def return_string(prompt: str) -> str:
return "string response"
# Use unique prompts to avoid cache collision
assert isinstance(return_dict("test_dict"), dict)
assert isinstance(return_list("test_list"), list)
assert isinstance(return_string("test_string"), str)
def test_empty_and_none_responses(self):
"""Test caching empty and None responses."""
backend = MemoryBackend()
@cache(backend=backend)
def return_empty(prompt: str) -> str:
return ""
@cache(backend=backend)
def return_none(prompt: str) -> None:
return None
assert return_empty("empty_test") == ""
assert return_none("none_test") is None
# Should still cache (second calls should hit cache)
assert return_empty("empty_test") == ""
assert return_none("none_test") is None
class TestErrorHandling:
"""Tests for error handling in various scenarios."""
def test_function_with_exception(self):
"""Test function that raises exception."""
from semantic_llm_cache.exceptions import PromptCacheError
backend = MemoryBackend()
@cache(backend=backend)
def failing_function(prompt: str) -> str:
if "error" in prompt:
raise ValueError("Test error")
return "OK"
# Normal call works
assert failing_function("normal") == "OK"
# Error call raises PromptCacheError (wrapped exception)
with pytest.raises(PromptCacheError):
failing_function("error prompt")
# Normal call still works
assert failing_function("normal") == "OK"
def test_backend_error_handling(self):
"""Test that backend wraps errors properly."""
from semantic_llm_cache.backends.memory import MemoryBackend
# Use MemoryBackend which has proper error handling
backend = MemoryBackend()
@cache(backend=backend)
def working_func(prompt: str) -> str:
return f"Response to: {prompt}"
# Normal operation works
assert working_func("test") == "Response to: test"
# Second call hits cache
assert working_func("test") == "Response to: test"
# Backend properly stores and retrieves entries
stats = backend.get_stats()
assert stats["hits"] >= 1
class TestPromptNormalization:
"""Tests for prompt normalization effects."""
def test_whitespace_normalization(self):
"""Test prompts with different whitespace are cached separately."""
backend = MemoryBackend()
call_count = {"count": 0}
@cache(backend=backend)
def llm_function(prompt: str) -> str:
call_count["count"] += 1
return f"Response: {prompt}"
llm_function("What is Python?")
llm_function("What is Python?") # Extra spaces
# Normalization should make these the same
# Note: This depends on the normalization implementation
assert call_count["count"] >= 1
def test_case_sensitivity(self):
"""Test case sensitivity in caching."""
backend = MemoryBackend()
call_count = {"count": 0}
@cache(backend=backend)
def llm_function(prompt: str) -> str:
call_count["count"] += 1
return f"Response: {prompt}"
llm_function("What is Python?")
llm_function("what is python?")
# Case differences create different cache entries
# (normalization doesn't lowercase by default)
assert call_count["count"] >= 1
class TestConfigurationCombinations:
"""Tests for various configuration combinations."""
def test_no_caching_config(self):
"""Test configuration with caching disabled."""
backend = MemoryBackend()
@cache(backend=backend, enabled=False)
def llm_function(prompt: str) -> str:
return f"Response: {time.time()}"
result1 = llm_function("test")
time.sleep(0.01)
result2 = llm_function("test")
# Without caching, results differ
assert result1 != result2
def test_zero_ttl(self):
"""Test zero TTL means immediate expiration."""
backend = MemoryBackend()
@cache(backend=backend, ttl=0)
def llm_function(prompt: str) -> str:
return f"Response: {prompt}"
llm_function("test")
# Entry immediately expires, so next call is a miss
llm_function("test")
def test_infinite_ttl(self):
"""Test None TTL means never expire."""
backend = MemoryBackend()
call_count = {"count": 0}
@cache(backend=backend, ttl=None)
def llm_function(prompt: str) -> str:
call_count["count"] += 1
return f"Response: {prompt}"
llm_function("test")
llm_function("test")
assert call_count["count"] == 1

208
tests/test_similarity.py Normal file
View file

@ -0,0 +1,208 @@
"""Tests for embedding generation and similarity matching."""
import numpy as np
import pytest
from semantic_llm_cache.similarity import (
DummyEmbeddingProvider,
EmbeddingCache,
OpenAIEmbeddingProvider,
SentenceTransformerProvider,
cosine_similarity,
create_embedding_provider,
)
class TestCosineSimilarity:
"""Tests for cosine_similarity function."""
def test_identical_vectors(self):
"""Test identical vectors have similarity 1.0."""
a = [1.0, 2.0, 3.0]
b = [1.0, 2.0, 3.0]
assert cosine_similarity(a, b) == pytest.approx(1.0)
def test_orthogonal_vectors(self):
"""Test orthogonal vectors have similarity 0.0."""
a = [1.0, 0.0, 0.0]
b = [0.0, 1.0, 0.0]
assert cosine_similarity(a, b) == pytest.approx(0.0)
def test_opposite_vectors(self):
"""Test opposite vectors have similarity -1.0."""
a = [1.0, 2.0, 3.0]
b = [-1.0, -2.0, -3.0]
assert cosine_similarity(a, b) == pytest.approx(-1.0)
def test_zero_vectors(self):
"""Test zero vectors return 0.0."""
a = [0.0, 0.0, 0.0]
b = [1.0, 2.0, 3.0]
assert cosine_similarity(a, b) == 0.0
def test_numpy_array_input(self):
"""Test function accepts numpy arrays."""
a = np.array([1.0, 2.0, 3.0])
b = np.array([1.0, 2.0, 3.0])
assert cosine_similarity(a, b) == pytest.approx(1.0)
def test_mixed_dimensions(self):
"""Test vectors of different dimensions raise error."""
a = [1.0, 2.0]
b = [1.0, 2.0, 3.0]
# Should raise ValueError for mismatched dimensions
with pytest.raises(ValueError, match="dimension mismatch"):
cosine_similarity(a, b)
class TestDummyEmbeddingProvider:
"""Tests for DummyEmbeddingProvider."""
def test_encode_returns_list(self):
"""Test encode returns list of floats."""
provider = DummyEmbeddingProvider()
embedding = provider.encode("test prompt")
assert isinstance(embedding, list)
assert all(isinstance(x, float) for x in embedding)
def test_encode_deterministic(self):
"""Test same input produces same output."""
provider = DummyEmbeddingProvider()
text = "test prompt"
e1 = provider.encode(text)
e2 = provider.encode(text)
assert e1 == e2
def test_encode_different_inputs(self):
"""Test different inputs produce different outputs."""
provider = DummyEmbeddingProvider()
e1 = provider.encode("prompt 1")
e2 = provider.encode("prompt 2")
assert e1 != e2
def test_custom_dimension(self):
"""Test custom embedding dimension."""
provider = DummyEmbeddingProvider(dim=128)
embedding = provider.encode("test")
assert len(embedding) == 128
def test_embedding_normalized(self):
"""Test embeddings are normalized to unit length."""
provider = DummyEmbeddingProvider()
embedding = provider.encode("test prompt")
# Calculate norm
norm = np.linalg.norm(embedding)
assert norm == pytest.approx(1.0, rel=1e-5)
class TestSentenceTransformerProvider:
"""Tests for SentenceTransformerProvider."""
@pytest.mark.skip(reason="Requires sentence-transformers installation")
def test_encode_returns_list(self):
"""Test encode returns list of floats."""
provider = SentenceTransformerProvider()
embedding = provider.encode("test prompt")
assert isinstance(embedding, list)
assert all(isinstance(x, float) for x in embedding)
@pytest.mark.skip(reason="Requires sentence-transformers installation")
def test_encode_deterministic(self):
"""Test same input produces same output."""
provider = SentenceTransformerProvider()
text = "test prompt"
e1 = provider.encode(text)
e2 = provider.encode(text)
assert e1 == e2
def test_import_error_without_package(self, monkeypatch):
"""Test ImportError raised when package not installed."""
# Skip if sentence-transformers is installed
pytest.importorskip("sentence_transformers", reason="sentence-transformers is installed, cannot test import error")
class TestOpenAIEmbeddingProvider:
"""Tests for OpenAIEmbeddingProvider."""
@pytest.mark.skip(reason="Requires OpenAI API key")
def test_encode_returns_list(self):
"""Test encode returns list of floats."""
provider = OpenAIEmbeddingProvider()
embedding = provider.encode("test prompt")
assert isinstance(embedding, list)
assert all(isinstance(x, float) for x in embedding)
@pytest.mark.skip(reason="Requires OpenAI API key")
def test_encode_deterministic(self):
"""Test same input produces same output."""
provider = OpenAIEmbeddingProvider()
text = "test prompt"
e1 = provider.encode(text)
e2 = provider.encode(text)
# OpenAI embeddings may vary slightly
assert len(e1) == len(e2)
def test_import_error_without_package(self, monkeypatch):
"""Test ImportError raised when package not installed."""
# Skip if openai is installed
pytest.importorskip("openai", reason="openai is installed, cannot test import error")
class TestEmbeddingCache:
"""Tests for EmbeddingCache."""
def test_cache_provider(self):
"""Test cache uses provider."""
provider = DummyEmbeddingProvider()
cache = EmbeddingCache(provider=provider)
# Encode same text twice
e1 = cache.encode("test prompt")
e2 = cache.encode("test prompt")
assert e1 == e2
def test_cache_clear(self):
"""Test cache can be cleared."""
provider = DummyEmbeddingProvider()
cache = EmbeddingCache(provider=provider)
cache.encode("test prompt")
cache.clear_cache()
# Should still work after clear
embedding = cache.encode("test prompt")
assert len(embedding) > 0
def test_cache_default_provider(self):
"""Test cache uses dummy provider by default."""
cache = EmbeddingCache()
embedding = cache.encode("test prompt")
assert isinstance(embedding, list)
assert len(embedding) > 0
class TestCreateEmbeddingProvider:
"""Tests for create_embedding_provider factory."""
def test_create_dummy_provider(self):
"""Test creating dummy provider."""
provider = create_embedding_provider("dummy")
assert isinstance(provider, DummyEmbeddingProvider)
def test_create_auto_provider(self):
"""Test auto provider creates sentence-transformers when available."""
provider = create_embedding_provider("auto")
# Creates SentenceTransformerProvider if available, else DummyEmbeddingProvider
assert isinstance(provider, (DummyEmbeddingProvider, SentenceTransformerProvider))
def test_invalid_provider_type(self):
"""Test invalid provider type raises error."""
with pytest.raises(ValueError, match="Unknown provider type"):
create_embedding_provider("invalid_type")
def test_custom_model_name(self, monkeypatch):
"""Test custom model name is passed through."""
# This would work with actual sentence-transformers
provider = create_embedding_provider("dummy", model_name="custom-model")
assert isinstance(provider, DummyEmbeddingProvider)

443
tests/test_stats.py Normal file
View file

@ -0,0 +1,443 @@
"""Tests for statistics and analytics module."""
import time
import pytest
from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.stats import (
CacheStats,
_stats_manager,
clear_cache,
export_cache,
get_stats,
invalidate,
warm_cache,
)
@pytest.fixture(autouse=True)
def clear_stats_state():
"""Clear stats state before each test."""
_stats_manager.clear_stats()
yield
class TestCacheStats:
"""Tests for CacheStats dataclass."""
def test_default_values(self):
"""Test CacheStats initializes with defaults."""
stats = CacheStats()
assert stats.hits == 0
assert stats.misses == 0
assert stats.total_saved_ms == 0.0
assert stats.estimated_savings_usd == 0.0
def test_hit_rate_empty(self):
"""Test hit rate with no requests."""
stats = CacheStats()
assert stats.hit_rate == 0.0
def test_hit_rate_all_hits(self):
"""Test hit rate with all cache hits."""
stats = CacheStats(hits=10, misses=0)
assert stats.hit_rate == 1.0
def test_hit_rate_all_misses(self):
"""Test hit rate with all cache misses."""
stats = CacheStats(hits=0, misses=10)
assert stats.hit_rate == 0.0
def test_hit_rate_mixed(self):
"""Test hit rate with mixed hits and misses."""
stats = CacheStats(hits=7, misses=3)
assert stats.hit_rate == 0.7
def test_total_requests(self):
"""Test total requests calculation."""
stats = CacheStats(hits=5, misses=3)
assert stats.total_requests == 8
def test_to_dict(self):
"""Test converting stats to dictionary."""
stats = CacheStats(hits=10, misses=5, total_saved_ms=1000.0, estimated_savings_usd=0.5)
result = stats.to_dict()
assert result["hits"] == 10
assert result["misses"] == 5
assert result["hit_rate"] == 2/3
assert result["total_requests"] == 15
assert result["total_saved_ms"] == 1000.0
assert result["estimated_savings_usd"] == 0.5
def test_iadd(self):
"""Test in-place addition of stats."""
stats1 = CacheStats(hits=5, misses=3, total_saved_ms=500.0, estimated_savings_usd=0.25)
stats2 = CacheStats(hits=3, misses=2, total_saved_ms=300.0, estimated_savings_usd=0.15)
stats1 += stats2
assert stats1.hits == 8
assert stats1.misses == 5
assert stats1.total_saved_ms == 800.0
assert stats1.estimated_savings_usd == 0.4
class TestStatsManager:
"""Tests for _StatsManager."""
def test_record_hit(self):
"""Test recording a cache hit."""
_stats_manager.record_hit("test_ns", latency_saved_ms=100.0, saved_cost=0.01)
stats = _stats_manager.get_stats("test_ns")
assert stats.hits == 1
assert stats.total_saved_ms == 100.0
assert stats.estimated_savings_usd == 0.01
def test_record_multiple_hits(self):
"""Test recording multiple hits."""
_stats_manager.record_hit("test_ns", latency_saved_ms=50.0, saved_cost=0.005)
_stats_manager.record_hit("test_ns", latency_saved_ms=75.0, saved_cost=0.008)
stats = _stats_manager.get_stats("test_ns")
assert stats.hits == 2
assert stats.total_saved_ms == 125.0
def test_record_miss(self):
"""Test recording a cache miss."""
_stats_manager.record_miss("test_ns")
stats = _stats_manager.get_stats("test_ns")
assert stats.misses == 1
def test_get_stats_namespace(self):
"""Test getting stats for specific namespace."""
_stats_manager.record_hit("ns1", latency_saved_ms=100.0)
_stats_manager.record_miss("ns1")
_stats_manager.record_hit("ns2", latency_saved_ms=50.0)
stats1 = _stats_manager.get_stats("ns1")
stats2 = _stats_manager.get_stats("ns2")
assert stats1.hits == 1
assert stats1.misses == 1
assert stats2.hits == 1
assert stats2.misses == 0
def test_get_stats_all_namespaces(self):
"""Test getting aggregated stats for all namespaces."""
_stats_manager.record_hit("ns1", latency_saved_ms=100.0)
_stats_manager.record_hit("ns2", latency_saved_ms=50.0)
_stats_manager.record_miss("ns1")
stats = _stats_manager.get_stats(None) # All namespaces
assert stats.hits == 2
assert stats.misses == 1
def test_get_stats_nonexistent_namespace(self):
"""Test getting stats for namespace with no activity."""
stats = _stats_manager.get_stats("nonexistent")
assert stats.hits == 0
assert stats.misses == 0
def test_clear_stats_namespace(self):
"""Test clearing stats for specific namespace."""
_stats_manager.record_hit("ns1", latency_saved_ms=100.0)
_stats_manager.record_hit("ns2", latency_saved_ms=50.0)
_stats_manager.clear_stats("ns1")
stats1 = _stats_manager.get_stats("ns1")
stats2 = _stats_manager.get_stats("ns2")
assert stats1.hits == 0
assert stats2.hits == 1
def test_clear_stats_all(self):
"""Test clearing all stats."""
_stats_manager.record_hit("ns1", latency_saved_ms=100.0)
_stats_manager.record_hit("ns2", latency_saved_ms=50.0)
_stats_manager.clear_stats()
stats = _stats_manager.get_stats(None)
assert stats.hits == 0
assert stats.misses == 0
def test_set_backend(self):
"""Test setting default backend."""
custom_backend = MemoryBackend(max_size=10)
_stats_manager.set_backend(custom_backend)
retrieved = _stats_manager.get_backend()
assert retrieved is custom_backend
class TestPublicStatsAPI:
"""Tests for public stats API functions."""
def test_get_stats(self):
"""Test get_stats returns dictionary."""
_stats_manager.record_hit("test_ns", latency_saved_ms=100.0)
stats = get_stats("test_ns")
assert isinstance(stats, dict)
assert "hits" in stats
assert "misses" in stats
assert "hit_rate" in stats
def test_clear_cache_all(self):
"""Test clear_cache clears all entries."""
backend = _stats_manager.get_backend()
# Add some entries
entry = CacheEntry(prompt="test", response="response", created_at=time.time())
backend.set("key1", entry)
backend.set("key2", entry)
count = clear_cache()
assert count >= 0
def test_clear_cache_namespace(self):
"""Test clear_cache clears specific namespace."""
backend = _stats_manager.get_backend()
# Add entries in different namespaces
entry1 = CacheEntry(prompt="test1", response="r1", namespace="ns1", created_at=time.time())
entry2 = CacheEntry(prompt="test2", response="r2", namespace="ns2", created_at=time.time())
backend.set("key1", entry1)
backend.set("key2", entry2)
count = clear_cache(namespace="ns1")
assert count >= 0
def test_invalidate_pattern(self):
"""Test invalidating entries by pattern."""
backend = _stats_manager.get_backend()
# Add entries with different prompts
entry1 = CacheEntry(prompt="Python programming", response="r1", created_at=time.time())
entry2 = CacheEntry(prompt="Rust programming", response="r2", created_at=time.time())
entry3 = CacheEntry(prompt="JavaScript", response="r3", created_at=time.time())
backend.set("key1", entry1)
backend.set("key2", entry2)
backend.set("key3", entry3)
count = invalidate("Python")
assert count >= 0
def test_invalidate_case_insensitive(self):
"""Test invalidate is case insensitive."""
backend = _stats_manager.get_backend()
entry = CacheEntry(prompt="PYTHON programming", response="r1", created_at=time.time())
backend.set("key1", entry)
count = invalidate("python")
assert count >= 0
def test_invalidate_no_matches(self):
"""Test invalidate with no matches."""
backend = _stats_manager.get_backend()
entry = CacheEntry(prompt="Rust programming", response="r1", created_at=time.time())
backend.set("key1", entry)
count = invalidate("Python")
assert count == 0
def test_warm_cache(self):
"""Test warming cache with prompts."""
prompts = ["prompt1", "prompt2"]
def mock_llm(prompt: str) -> str:
return f"Response to: {prompt}"
count = warm_cache(prompts, mock_llm, namespace="warm_test")
assert count == len(prompts)
def test_warm_cache_with_failures(self):
"""Test warm_cache handles LLM failures gracefully."""
def failing_llm(prompt: str) -> str:
if "fail" in prompt:
raise ValueError("LLM error")
return f"Response to: {prompt}"
prompts = ["prompt1", "fail_prompt", "prompt3"]
count = warm_cache(prompts, failing_llm, namespace="warm_fail_test")
# Should return count even if some prompts fail
assert count == len(prompts)
class TestExportCache:
"""Tests for export_cache function."""
def test_export_all_entries(self, tmp_path):
"""Test exporting all cache entries."""
backend = _stats_manager.get_backend()
backend.clear()
# Add test entries
entry1 = CacheEntry(
prompt="test prompt 1",
response="response 1",
namespace="test_ns",
created_at=time.time(),
hit_count=5,
ttl=3600,
input_tokens=100,
output_tokens=50,
)
backend.set("key1", entry1)
entries = export_cache()
assert len(entries) >= 0
if entries:
assert "key" in entries[0]
assert "prompt" in entries[0]
assert "response" in entries[0]
assert "namespace" in entries[0]
assert "hit_count" in entries[0]
def test_export_namespace_filtered(self, tmp_path):
"""Test exporting entries filtered by namespace."""
backend = _stats_manager.get_backend()
backend.clear()
# Add entries in different namespaces
entry1 = CacheEntry(
prompt="test1", response="r1", namespace="ns1", created_at=time.time()
)
entry2 = CacheEntry(
prompt="test2", response="r2", namespace="ns2", created_at=time.time()
)
backend.set("key1", entry1)
backend.set("key2", entry2)
entries = export_cache(namespace="ns1")
# Should only return entries from ns1
assert all(e["namespace"] == "ns1" for e in entries)
def test_export_to_file(self, tmp_path):
"""Test exporting cache to JSON file."""
import json
filepath = tmp_path / "export.json"
backend = _stats_manager.get_backend()
backend.clear()
entry = CacheEntry(
prompt="test",
response="response",
namespace="test",
created_at=time.time(),
hit_count=3,
)
backend.set("key1", entry)
export_cache(filepath=str(filepath))
# Verify file was created and is valid JSON
assert filepath.exists()
with open(filepath) as f:
data = json.load(f)
assert isinstance(data, list)
def test_export_truncates_large_responses(self, tmp_path):
"""Test that large responses are truncated in export."""
backend = _stats_manager.get_backend()
backend.clear()
# Create entry with very large response
large_response = "x" * 2000
entry = CacheEntry(
prompt="test",
response=large_response,
created_at=time.time(),
)
backend.set("key1", entry)
entries = export_cache()
if entries:
# Response should be truncated to 1000 chars
assert len(entries[0]["response"]) <= 1000
class TestStatsIntegration:
"""Integration tests for stats with actual cache operations."""
def test_stats_tracking_with_cache_decorator(self):
"""Test that stats are tracked during cache operations."""
from semantic_llm_cache import cache
backend = MemoryBackend()
_stats_manager.clear_stats("integration_test")
@cache(backend=backend, namespace="integration_test")
def cached_func(prompt: str) -> str:
return f"Response to: {prompt}"
# Generate activity
cached_func("prompt1")
cached_func("prompt1") # Hit
cached_func("prompt2")
stats = get_stats("integration_test")
assert stats["total_requests"] >= 2
def test_cache_invalidate_integration(self):
"""Test invalidate removes entries from backend."""
backend = _stats_manager.get_backend()
backend.clear()
entry = CacheEntry(
prompt="Python is great",
response="Yes, it is!",
created_at=time.time(),
)
backend.set("key1", entry)
# Verify entry exists
assert backend.get("key1") is not None
# Invalidate
invalidate("Python")
# Entry should be gone
assert backend.get("key1") is None
def test_export_includes_metadata(self, tmp_path):
"""Test export includes all metadata fields."""
backend = _stats_manager.get_backend()
backend.clear()
entry = CacheEntry(
prompt="test prompt",
response="test response",
namespace="export_test",
created_at=time.time(),
ttl=7200,
hit_count=10,
input_tokens=500,
output_tokens=250,
embedding=[0.1, 0.2, 0.3],
)
backend.set("key1", entry)
entries = export_cache(namespace="export_test")
if entries:
e = entries[0]
assert "created_at" in e
assert "ttl" in e
assert "hit_count" in e
assert "input_tokens" in e
assert "output_tokens" in e