688 lines
24 KiB
Python
688 lines
24 KiB
Python
|
|
"""Tests for storage backends."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
from unittest.mock import MagicMock, patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from semantic_llm_cache.backends import MemoryBackend
|
||
|
|
from semantic_llm_cache.backends.sqlite import SQLiteBackend # noqa: F401
|
||
|
|
from semantic_llm_cache.config import CacheEntry
|
||
|
|
from semantic_llm_cache.exceptions import CacheBackendError
|
||
|
|
|
||
|
|
|
||
|
|
class TestBaseBackend:
    """Tests for the helper logic shared through the BaseBackend abstract class."""

    def test_cosine_similarity(self):
        """Check that ``_find_best_match`` scores an identical embedding as ~1.0."""
        # MemoryBackend serves as the concrete backend on which the private
        # matching helper is invoked.
        store = MemoryBackend()

        # key1 and key2 carry the same vector as the query; key3 is orthogonal.
        vectors = {
            "key1": [1.0, 0.0, 0.0],
            "key2": [1.0, 0.0, 0.0],
            "key3": [0.0, 1.0, 0.0],
        }
        candidates = [
            (name, CacheEntry(prompt="test", response="response", embedding=vector))
            for name, vector in vectors.items()
        ]

        match = store._find_best_match(candidates, [1.0, 0.0, 0.0], threshold=0.9)

        assert match is not None
        # The result unpacks as (key, entry, similarity); only the similarity
        # is checked here.
        _key, _entry, score = match
        assert score == pytest.approx(1.0)
|
||
|
|
|
||
|
|
|
||
|
|
class TestMemoryBackend:
    """Behavioural tests for MemoryBackend.

    The ``backend`` and ``sample_entry`` fixtures are not defined in this
    class; presumably they come from a shared conftest -- verify there.
    """

    def test_set_and_get(self, backend, sample_entry):
        """A stored entry can be read back with prompt and response intact."""
        backend.set("key1", sample_entry)

        fetched = backend.get("key1")

        assert fetched is not None
        assert fetched.prompt == sample_entry.prompt
        assert fetched.response == sample_entry.response

    def test_get_nonexistent(self, backend):
        """Looking up a key that was never stored yields None."""
        missing = backend.get("nonexistent")

        assert missing is None

    def test_delete(self, backend, sample_entry):
        """Deleting a stored key reports True and removes the entry."""
        backend.set("key1", sample_entry)
        before = backend.get("key1")
        assert before is not None

        removed = backend.delete("key1")

        assert removed is True
        assert backend.get("key1") is None

    def test_delete_nonexistent(self, backend):
        """Deleting an unknown key reports False."""
        removed = backend.delete("nonexistent")

        assert removed is False

    def test_clear(self, backend, sample_entry):
        """After clear() none of the previously stored keys is retrievable."""
        for name in ("key1", "key2"):
            backend.set(name, sample_entry)

        backend.clear()

        assert backend.get("key1") is None
        assert backend.get("key2") is None

    def test_iterate_all(self, backend):
        """iterate() without a filter returns every stored entry."""
        for index in (1, 2):
            backend.set(
                f"key{index}",
                CacheEntry(
                    prompt=f"p{index}",
                    response=f"r{index}",
                    created_at=time.time(),
                ),
            )

        everything = backend.iterate()

        assert len(everything) == 2

    def test_iterate_with_namespace(self, backend):
        """iterate(namespace=...) returns only entries of that namespace."""
        for index in (1, 2):
            backend.set(
                f"key{index}",
                CacheEntry(
                    prompt=f"p{index}",
                    response=f"r{index}",
                    namespace=f"ns{index}",
                    created_at=time.time(),
                ),
            )

        filtered = backend.iterate(namespace="ns1")

        assert len(filtered) == 1
        # Each result is indexable; position 1 holds the entry.
        only = filtered[0]
        assert only[1].namespace == "ns1"

    def test_find_similar(self, backend):
        """find_similar returns the key whose embedding matches the query."""
        stored = {
            "key1": ("What is Python?", "r1", [1.0, 0.0, 0.0]),
            "key2": ("What is Rust?", "r2", [0.0, 1.0, 0.0]),
        }
        for name, (prompt, response, vector) in stored.items():
            backend.set(
                name,
                CacheEntry(
                    prompt=prompt,
                    response=response,
                    embedding=vector,
                    created_at=time.time(),
                ),
            )

        # The query equals key1's embedding and is orthogonal to key2's.
        hit = backend.find_similar([1.0, 0.0, 0.0], threshold=0.9)

        assert hit is not None
        matched_key, _entry, _score = hit
        assert matched_key == "key1"

    def test_find_similar_no_match(self, backend):
        """find_similar yields None when no entry reaches the threshold."""
        backend.set(
            "key1",
            CacheEntry(
                prompt="test",
                response="response",
                embedding=[1.0, 0.0, 0.0],
                created_at=time.time(),
            ),
        )

        # The query vector is orthogonal to the only stored embedding.
        hit = backend.find_similar([0.0, 1.0, 0.0], threshold=0.9)

        assert hit is None

    def test_get_stats(self, backend):
        """get_stats reports the entry count and exposes hit/miss counters."""
        backend.set(
            "key1",
            CacheEntry(prompt="test", response="response", created_at=time.time()),
        )

        stats = backend.get_stats()

        assert stats["size"] == 1
        for counter in ("hits", "misses"):
            assert counter in stats

    def test_hit_count_increments(self, backend, sample_entry):
        """Repeated reads leave the entry with a hit count of at least one."""
        backend.set("key1", sample_entry)

        # Two warm-up reads, then a third whose result is inspected.
        for _ in range(2):
            backend.get("key1")
        entry = backend.get("key1")

        assert entry.hit_count >= 1

    def test_lru_eviction(self):
        """The backend never holds more than max_size entries."""
        bounded = MemoryBackend(max_size=2)
        first, second, third = (
            CacheEntry(prompt=f"p{index}", response=f"r{index}", created_at=time.time())
            for index in (1, 2, 3)
        )

        bounded.set("key1", first)
        bounded.set("key2", second)
        assert bounded.get_stats()["size"] == 2

        # A third insert is expected to evict the oldest key (key1); only the
        # resulting size is asserted here, not which key was dropped.
        bounded.set("key3", third)
        assert bounded.get_stats()["size"] == 2

    def test_expired_entry_not_returned(self, backend):
        """An entry whose TTL has already elapsed is not returned by get()."""
        stale = CacheEntry(
            prompt="test",
            response="response",
            ttl=1,
            created_at=time.time() - 2,  # created 2s ago with a 1s TTL
        )
        backend.set("key1", stale)

        fetched = backend.get("key1")

        assert fetched is None
|
||
|
|
|
||
|
|
|
||
|
|
class TestSQLiteBackend:
    """Tests for SQLiteBackend, each run against a throwaway database file.

    ``sample_entry`` is not defined in this class; presumably it comes from a
    shared conftest -- verify there.
    """

    @pytest.fixture
    def sqlite_backend(self, tmp_path):
        """Yield a SQLiteBackend on a per-test temp database, closing it afterwards."""
        backend = SQLiteBackend(tmp_path / "test_cache.db")
        yield backend
        # Teardown: close the backend so no connection outlives the test and
        # the temp directory can be removed cleanly.
        backend.close()

    def test_set_and_get(self, sqlite_backend, sample_entry):
        """Test basic set and get operations."""
        sqlite_backend.set("key1", sample_entry)
        retrieved = sqlite_backend.get("key1")

        assert retrieved is not None
        assert retrieved.prompt == sample_entry.prompt
        assert retrieved.response == sample_entry.response

    def test_persistence(self, sqlite_backend, sample_entry, tmp_path):
        """Test entries persist across backend instances.

        Uses its own database file; the ``sqlite_backend`` fixture is requested
        but not otherwise used.
        """
        db_path = tmp_path / "test_persist.db"

        # First instance writes the entry and is closed, like a process exit.
        writer = SQLiteBackend(db_path)
        try:
            writer.set("key1", sample_entry)
        finally:
            writer.close()

        # Second instance on the same file simulates a restart.
        reader = SQLiteBackend(db_path)
        try:
            retrieved = reader.get("key1")
        finally:
            reader.close()

        assert retrieved is not None
        assert retrieved.prompt == sample_entry.prompt

    def test_get_stats(self, sqlite_backend):
        """Test get_stats returns the entry count and the database path."""
        entry = CacheEntry(prompt="test", response="response", created_at=time.time())
        sqlite_backend.set("key1", entry)

        stats = sqlite_backend.get_stats()
        assert stats["size"] == 1
        assert "db_path" in stats

    def test_clear(self, sqlite_backend, sample_entry):
        """Test clear operation."""
        sqlite_backend.set("key1", sample_entry)
        sqlite_backend.clear()

        assert sqlite_backend.get("key1") is None

    def test_close_and_reopen(self, sqlite_backend, sample_entry):
        """Test closing and reopening connection."""
        sqlite_backend.set("key1", sample_entry)
        sqlite_backend.close()

        # Should be able to use after close (reopens connection)
        retrieved = sqlite_backend.get("key1")
        assert retrieved is not None
|
||
|
|
|
||
|
|
|
||
|
|
class TestRedisBackend:
    """Tests for RedisBackend, run entirely against a mocked Redis client.

    ``redis_lib`` inside the backend module is patched and the backend's
    ``_redis`` attribute is replaced by a MagicMock. ``sample_entry`` is not
    defined in this class; presumably it comes from a shared conftest.
    """

    @staticmethod
    def _payload(**overrides):
        """Return the JSON bytes the mocked ``get`` hands back for one stored entry.

        Every field of the record gets a neutral default (no embedding, no TTL,
        ``"default"`` namespace, zero counters, created now); keyword arguments
        override individual fields.
        """
        record = {
            "prompt": "test",
            "response": "response",
            "embedding": None,
            "created_at": time.time(),
            "ttl": None,
            "namespace": "default",
            "hit_count": 0,
            "input_tokens": 0,
            "output_tokens": 0,
        }
        record.update(overrides)
        return json.dumps(record).encode()

    @classmethod
    def _payload_for(cls, entry):
        """Return the stored JSON bytes for *entry* with a zero hit count."""
        return cls._payload(
            prompt=entry.prompt,
            response=entry.response,
            embedding=entry.embedding,
            created_at=entry.created_at,
            ttl=entry.ttl,
            namespace=entry.namespace,
            hit_count=0,
            input_tokens=entry.input_tokens,
            output_tokens=entry.output_tokens,
        )

    @staticmethod
    def _first_then(first, rest):
        """Build a ``get`` side effect: *first* on the first call, *rest* afterwards."""
        calls = 0

        def fake_get(key):
            nonlocal calls
            calls += 1
            return first if calls == 1 else rest

        return fake_get

    @pytest.fixture
    def mock_redis(self):
        """Create mock Redis client.

        The patch on ``redis_lib`` stays active for the duration of the test
        because the fixture yields from inside the ``with`` block.
        """
        with patch("semantic_llm_cache.backends.redis.redis_lib") as mock:
            mock_client = MagicMock()
            mock.from_url.return_value = mock_client
            mock_client.ping.return_value = True
            mock_client.get.return_value = None
            mock_client.keys.return_value = []
            mock_client.delete.return_value = 1
            yield mock_client

    @pytest.fixture
    def redis_backend(self, mock_redis):
        """Create Redis backend with mocked client."""
        from semantic_llm_cache.backends.redis import RedisBackend

        backend = RedisBackend(url="redis://localhost:6379/0")
        # Pin the client explicitly so every test talks to the same mock.
        backend._redis = mock_redis
        return backend

    def test_set_and_get(self, redis_backend, mock_redis, sample_entry):
        """Test basic set and get operations."""
        # Mock get to return stored data
        mock_redis.get.return_value = self._payload_for(sample_entry)

        redis_backend.set("key1", sample_entry)
        retrieved = redis_backend.get("key1")

        assert retrieved is not None
        assert retrieved.prompt == sample_entry.prompt
        # set is called twice: once for initial set, once to update hit_count
        assert mock_redis.set.call_count == 2

    def test_get_nonexistent(self, redis_backend, mock_redis):
        """Test getting non-existent key returns None."""
        mock_redis.get.return_value = None
        result = redis_backend.get("nonexistent")
        assert result is None

    def test_delete(self, redis_backend, mock_redis, sample_entry):
        """Test delete operation."""
        mock_redis.delete.return_value = 1
        result = redis_backend.delete("key1")
        assert result is True
        mock_redis.delete.assert_called_once()

    def test_delete_nonexistent(self, redis_backend, mock_redis):
        """Test deleting non-existent key returns False."""
        mock_redis.delete.return_value = 0
        result = redis_backend.delete("nonexistent")
        assert result is False

    def test_clear(self, redis_backend, mock_redis):
        """Test clear operation."""
        mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
        redis_backend.clear()
        mock_redis.delete.assert_called_once()

    def test_clear_empty(self, redis_backend, mock_redis):
        """Test clear with no entries."""
        mock_redis.keys.return_value = []
        redis_backend.clear()
        mock_redis.delete.assert_not_called()

    def test_iterate_all(self, redis_backend, mock_redis, sample_entry):
        """Test iterating over all entries."""
        mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
        # Both keys resolve to the same stored record.
        mock_redis.get.return_value = self._payload_for(sample_entry)

        results = redis_backend.iterate()
        assert len(results) == 2

    def test_iterate_with_namespace(self, redis_backend, mock_redis, sample_entry):
        """Test iterating with namespace filter."""
        created = time.time()
        in_ns1 = self._payload(prompt="p1", response="r1", created_at=created, namespace="ns1")
        in_ns2 = self._payload(prompt="p1", response="r1", created_at=created, namespace="ns2")

        mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
        # First lookup returns the ns1 record, every later lookup the ns2 one.
        mock_redis.get.side_effect = self._first_then(in_ns1, in_ns2)

        results = redis_backend.iterate(namespace="ns1")
        assert len(results) == 1
        assert results[0][1].namespace == "ns1"

    def test_find_similar(self, redis_backend, mock_redis):
        """Test finding semantically similar entries."""
        mock_redis.keys.return_value = [b"semantic_llm_cache:key1"]
        mock_redis.get.return_value = self._payload(
            prompt="What is Python?",
            response="r1",
            embedding=[1.0, 0.0, 0.0],
        )

        result = redis_backend.find_similar([1.0, 0.0, 0.0], threshold=0.9)
        assert result is not None
        key, _entry, _sim = result
        assert key == "key1"

    def test_find_similar_no_match(self, redis_backend, mock_redis):
        """Test find_similar returns None when below threshold."""
        mock_redis.keys.return_value = [b"semantic_llm_cache:key1"]
        mock_redis.get.return_value = self._payload(embedding=[1.0, 0.0, 0.0])

        # Query with orthogonal vector
        result = redis_backend.find_similar([0.0, 1.0, 0.0], threshold=0.9)
        assert result is None

    def test_get_stats(self, redis_backend, mock_redis):
        """Test get_stats returns correct info."""
        mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
        stats = redis_backend.get_stats()

        assert "prefix" in stats
        assert stats["size"] == 2
        assert stats["prefix"] == "semantic_llm_cache:"

    def test_get_stats_error_handling(self, redis_backend, mock_redis):
        """Test get_stats handles Redis errors gracefully."""
        mock_redis.keys.side_effect = Exception("Connection lost")
        stats = redis_backend.get_stats()

        assert "error" in stats
        assert stats["size"] == 0

    def test_make_key(self, redis_backend):
        """Test key prefixing."""
        result = redis_backend._make_key("test_key")
        assert result == "semantic_llm_cache:test_key"

    def test_entry_to_dict(self, redis_backend, sample_entry):
        """Test converting entry to dictionary."""
        result = redis_backend._entry_to_dict(sample_entry)
        assert result["prompt"] == sample_entry.prompt
        assert result["response"] == sample_entry.response
        assert result["embedding"] == sample_entry.embedding

    def test_dict_to_entry(self, redis_backend):
        """Test converting dictionary to entry."""
        data = {
            "prompt": "test",
            "response": "response",
            "embedding": [1.0, 0.0],
            "created_at": time.time(),
            "ttl": 100,
            "namespace": "test_ns",
            "hit_count": 5,
            "input_tokens": 100,
            "output_tokens": 50,
        }

        entry = redis_backend._dict_to_entry(data)
        assert entry.prompt == "test"
        assert entry.namespace == "test_ns"
        assert entry.hit_count == 5

    def test_dict_to_entry_defaults(self, redis_backend):
        """Test dict_to_entry uses defaults for missing fields."""
        data = {
            "prompt": "test",
            "response": "response",
            "created_at": time.time(),
        }

        entry = redis_backend._dict_to_entry(data)
        assert entry.embedding is None
        assert entry.ttl is None
        assert entry.namespace == "default"
        assert entry.hit_count == 0
        assert entry.input_tokens == 0
        assert entry.output_tokens == 0

    def test_connection_failure(self):
        """Test connection failure raises CacheBackendError."""
        from semantic_llm_cache.backends.redis import RedisBackend

        # The client returned by from_url fails on ping.
        with patch("semantic_llm_cache.backends.redis.redis_lib") as redis_lib:
            failing_client = MagicMock()
            failing_client.ping.side_effect = Exception("Connection refused")
            redis_lib.from_url.return_value = failing_client

            with pytest.raises(CacheBackendError, match="Failed to connect"):
                RedisBackend(url="redis://localhost:9999/0")

    def test_set_with_ttl(self, redis_backend, mock_redis, sample_entry):
        """Test setting entry with TTL passes it to Redis as ``ex``."""
        sample_entry.ttl = 3600
        redis_backend.set("key1", sample_entry)

        # call_args unpacks to (positional args, keyword args).
        _args, kwargs = mock_redis.set.call_args
        assert kwargs["ex"] == 3600

    def test_get_expired_entry(self, redis_backend, mock_redis):
        """Test expired entry is not returned and is deleted."""
        # 100 seconds TTL, created 1000 seconds ago
        mock_redis.get.return_value = self._payload(created_at=time.time() - 1000, ttl=100)
        mock_redis.delete.return_value = 1

        result = redis_backend.get("expired_key")
        assert result is None
        mock_redis.delete.assert_called_once()

    def test_close(self, redis_backend, mock_redis):
        """Test closing Redis connection."""
        redis_backend.close()
        mock_redis.close.assert_called_once()

    def test_close_error_handling(self, redis_backend, mock_redis):
        """Test close handles errors gracefully."""
        mock_redis.close.side_effect = Exception("Close error")
        # Should not raise
        redis_backend.close()

    def test_iterate_with_expired_entries(self, redis_backend, mock_redis):
        """Test iterate filters out expired entries."""
        expired = self._payload(prompt="expired", created_at=time.time() - 1000, ttl=100)
        valid = self._payload(prompt="valid")

        mock_redis.keys.return_value = [b"semantic_llm_cache:expired", b"semantic_llm_cache:valid"]
        # First lookup returns the expired record, every later lookup the valid one.
        mock_redis.get.side_effect = self._first_then(expired, valid)

        results = redis_backend.iterate()
        # Only valid entry should be returned
        assert len(results) == 1
        assert results[0][1].prompt == "valid"

    def test_set_error_handling(self, redis_backend, mock_redis, sample_entry):
        """Test set wraps Redis errors in CacheBackendError."""
        mock_redis.set.side_effect = Exception("Redis error")

        with pytest.raises(CacheBackendError, match="Failed to set"):
            redis_backend.set("key1", sample_entry)

    def test_delete_error_handling(self, redis_backend, mock_redis):
        """Test delete wraps Redis errors in CacheBackendError."""
        mock_redis.delete.side_effect = Exception("Redis error")

        with pytest.raises(CacheBackendError, match="Failed to delete"):
            redis_backend.delete("key1")

    def test_iterate_error_handling(self, redis_backend, mock_redis):
        """Test iterate wraps Redis errors in CacheBackendError."""
        mock_redis.keys.side_effect = Exception("Redis error")

        with pytest.raises(CacheBackendError, match="Failed to iterate"):
            redis_backend.iterate()

    def test_get_json_error(self, redis_backend, mock_redis):
        """Test get raises when the stored value is not valid JSON."""
        mock_redis.get.return_value = b"invalid json"

        # Either a wrapping CacheBackendError or the raw JSONDecodeError is
        # accepted; the test does not pin which one the backend raises.
        with pytest.raises((CacheBackendError, json.JSONDecodeError)):
            redis_backend.get("key1")

    def test_import_error_without_package(self):
        """Test the redis backend module imports and exposes RedisBackend.

        Despite its name, this test does not simulate a missing ``redis``
        package; it only checks that the module defines ``RedisBackend``.
        """
        from semantic_llm_cache.backends import redis as redis_module

        assert hasattr(redis_module, "RedisBackend")
|
||
|
|
|
||
|
|
|
||
|
|
class TestCacheEntry:
    """Tests for the CacheEntry dataclass: expiry checks and cost estimation."""

    def test_is_expired_with_none_ttl(self):
        """An entry without a TTL is not expired, however old it is."""
        fields = {
            "prompt": "test",
            "response": "response",
            "ttl": None,
            "created_at": time.time() - 10000,
        }
        entry = CacheEntry(**fields)

        expired = entry.is_expired(time.time())

        assert not expired

    def test_is_expired_with_ttl(self):
        """An entry older than its TTL is reported as expired."""
        fields = {
            "prompt": "test",
            "response": "response",
            "ttl": 10,
            "created_at": time.time() - 15,  # 15s old with a 10s TTL
        }
        entry = CacheEntry(**fields)

        expired = entry.is_expired(time.time())

        assert expired

    def test_is_expired_not_yet(self):
        """An entry younger than its TTL is not expired."""
        fields = {
            "prompt": "test",
            "response": "response",
            "ttl": 10,
            "created_at": time.time() - 5,  # 5s old with a 10s TTL
        }
        entry = CacheEntry(**fields)

        expired = entry.is_expired(time.time())

        assert not expired

    def test_estimate_cost(self):
        """estimate_cost combines input and output token counts with their rates."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            input_tokens=1000,
            output_tokens=500,
        )

        estimated = entry.estimate_cost(0.001, 0.002)

        # Expected: 1000/1000 * 0.001 + 500/1000 * 0.002 = 0.001 + 0.001 = 0.002
        assert estimated == pytest.approx(0.002)
|