Add files via upload
initial commit
This commit is contained in:
parent
8d3d5ff628
commit
b33bb415dd
24 changed files with 4840 additions and 0 deletions
46
Makefile
Normal file
46
Makefile
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
.PHONY: help install install-dev test lint format clean build publish
|
||||
|
||||
help:
|
||||
@echo "Available commands:"
|
||||
@echo " make install - Install package"
|
||||
@echo " make install-dev - Install with dev dependencies"
|
||||
@echo " make test - Run tests with coverage"
|
||||
@echo " make lint - Run linting (ruff, black, mypy)"
|
||||
@echo " make format - Format code with black and ruff"
|
||||
@echo " make clean - Clean build artifacts"
|
||||
@echo " make build - Build distribution packages"
|
||||
@echo " make publish - Publish to PyPI"
|
||||
|
||||
install:
|
||||
pip install -e .
|
||||
|
||||
install-dev:
|
||||
pip install -e ".[dev,all]"
|
||||
|
||||
test:
|
||||
pytest
|
||||
|
||||
lint:
|
||||
ruff check .
|
||||
black --check .
|
||||
mypy .
|
||||
|
||||
format:
|
||||
ruff check --fix .
|
||||
black .
|
||||
|
||||
clean:
|
||||
rm -rf build/
|
||||
rm -rf dist/
|
||||
rm -rf *.egg-info/
|
||||
rm -rf .pytest_cache/
|
||||
rm -rf htmlcov/
|
||||
rm -rf .coverage
|
||||
find . -type d -name __pycache__ -exec rm -rf {} +
|
||||
find . -type f -name "*.pyc" -delete
|
||||
|
||||
build: clean
|
||||
python -m build
|
||||
|
||||
publish: build
|
||||
twine upload dist/*
|
||||
279
README.md
Normal file
279
README.md
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
# semantic-llm-cache
|
||||
|
||||
**Async semantic caching for LLM API calls — reduce costs with one decorator.**
|
||||
|
||||
[](https://pypi.org/project/semantic-llm-cache/)
|
||||
[](LICENSE)
|
||||
[](https://pypi.org/project/semantic-llm-cache/)
|
||||
|
||||
> **Fork of [karthyick/prompt-cache](https://github.com/karthyick/prompt-cache)** — fully converted to async for use with async frameworks (FastAPI, aiohttp, Starlette, etc.).
|
||||
|
||||
## Overview
|
||||
|
||||
LLM API calls are expensive and slow. In production applications, **20-40% of prompts are semantically identical** but get charged as separate API calls. `semantic-llm-cache` solves this with a simple decorator that:
|
||||
|
||||
- ✅ **Caches semantically similar prompts** (not just exact matches)
|
||||
- ✅ **Reduces API costs by 20-40%**
|
||||
- ✅ **Returns cached responses in <10ms**
|
||||
- ✅ **Works with any LLM provider** (OpenAI, Anthropic, Ollama, local models)
|
||||
- ✅ **Fully async** — native `async/await` throughout, no event loop blocking
|
||||
- ✅ **Auto-detects** sync vs async decorated functions — one decorator for both
|
||||
|
||||
## What changed from the original
|
||||
|
||||
| Area | Original | This fork |
|
||||
| -------------------- | ------------------------- | ------------------------------------------------------------------- |
|
||||
| Backends | sync (`sqlite3`, `redis`) | async (`aiosqlite`, `redis.asyncio`) |
|
||||
| `@cache` decorator | sync only | auto-detects async/sync |
|
||||
| `EmbeddingCache` | sync `encode()` | adds `async aencode()` via `asyncio.to_thread` |
|
||||
| `CacheContext` | sync only | supports both `with` and `async with` |
|
||||
| `CachedLLM` | `chat()` | adds `achat()` |
|
||||
| Utility functions | sync | `clear_cache`, `invalidate`, `warm_cache`, `export_cache` all async |
|
||||
| `StorageBackend` ABC | sync abstract methods | all abstract methods are `async def` |
|
||||
| Min Python | 3.9 | 3.10 (uses `X \| Y` union syntax) |
|
||||
|
||||
## Installation
|
||||
|
||||
Not yet published to PyPI. Install directly from the repository:
|
||||
|
||||
```bash
|
||||
# Clone
|
||||
git clone https://github.com/YOUR_ORG/prompt-cache.git
|
||||
cd prompt-cache
|
||||
|
||||
# Core (exact match only, SQLite backend)
|
||||
pip install .
|
||||
|
||||
# With semantic similarity (sentence-transformers)
|
||||
pip install ".[semantic]"
|
||||
|
||||
# With Redis backend
|
||||
pip install ".[redis]"
|
||||
|
||||
# Everything
|
||||
pip install ".[all]"
|
||||
```
|
||||
|
||||
Or install directly via pip from git:
|
||||
|
||||
```bash
|
||||
pip install "git+https://github.com/nomyo-ai/.git"
|
||||
pip install "git+https://github.com/nomyo-ai/async-semantic-llm-cache.git[semantic]"
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Async function (FastAPI, aiohttp, etc.)
|
||||
|
||||
```python
|
||||
from semantic_llm_cache import cache
|
||||
|
||||
@cache(similarity=0.95, ttl=3600)
|
||||
async def ask_llm(prompt: str) -> str:
|
||||
return await call_ollama(prompt)
|
||||
|
||||
# First call — LLM hit
|
||||
await ask_llm("What is Python?")
|
||||
|
||||
# Second call — cache hit (<10ms, free)
|
||||
await ask_llm("What's Python?") # 95% similar → cache hit
|
||||
```
|
||||
|
||||
### Sync function (backwards compatible)
|
||||
|
||||
```python
|
||||
from semantic_llm_cache import cache
|
||||
|
||||
@cache()
|
||||
def ask_llm_sync(prompt: str) -> str:
|
||||
return call_openai(prompt) # works, but don't use inside a running event loop
|
||||
```
|
||||
|
||||
### Semantic Matching
|
||||
|
||||
```python
|
||||
from semantic_llm_cache import cache
|
||||
|
||||
@cache(similarity=0.90)
|
||||
async def ask_llm(prompt: str) -> str:
|
||||
return await call_ollama(prompt)
|
||||
|
||||
await ask_llm("What is Python?") # LLM call
|
||||
await ask_llm("What's Python?") # cache hit (95% similar)
|
||||
await ask_llm("Explain Python") # cache hit (91% similar)
|
||||
await ask_llm("What is Rust?") # LLM call (different topic)
|
||||
```
|
||||
|
||||
### SQLite backend (default, persistent)
|
||||
|
||||
```python
|
||||
from semantic_llm_cache import cache
|
||||
from semantic_llm_cache.backends import SQLiteBackend
|
||||
|
||||
backend = SQLiteBackend(db_path="my_cache.db")
|
||||
|
||||
@cache(backend=backend, similarity=0.95)
|
||||
async def ask_llm(prompt: str) -> str:
|
||||
return await call_ollama(prompt)
|
||||
```
|
||||
|
||||
### Redis backend (distributed)
|
||||
|
||||
```python
|
||||
from semantic_llm_cache import cache
|
||||
from semantic_llm_cache.backends import RedisBackend
|
||||
|
||||
backend = RedisBackend(url="redis://localhost:6379/0")
|
||||
await backend.ping() # verify connection (replaces __init__ connection test)
|
||||
|
||||
@cache(backend=backend, similarity=0.95)
|
||||
async def ask_llm(prompt: str) -> str:
|
||||
return await call_ollama(prompt)
|
||||
```
|
||||
|
||||
### Cache Statistics
|
||||
|
||||
```python
|
||||
from semantic_llm_cache import get_stats
|
||||
|
||||
stats = get_stats()
|
||||
# {
|
||||
# "hits": 1547,
|
||||
# "misses": 892,
|
||||
# "hit_rate": 0.634,
|
||||
# "estimated_savings_usd": 3.09,
|
||||
# "total_saved_ms": 773500
|
||||
# }
|
||||
```
|
||||
|
||||
### Cache Management
|
||||
|
||||
```python
|
||||
from semantic_llm_cache.stats import clear_cache, invalidate
|
||||
|
||||
# Clear all cached entries
|
||||
await clear_cache()
|
||||
|
||||
# Invalidate entries matching a pattern
|
||||
await invalidate(pattern="Python")
|
||||
```
|
||||
|
||||
### Async context manager
|
||||
|
||||
```python
|
||||
from semantic_llm_cache import CacheContext
|
||||
|
||||
async with CacheContext(similarity=0.9) as ctx:
|
||||
result1 = await any_cached_llm_call("prompt 1")
|
||||
result2 = await any_cached_llm_call("prompt 2")
|
||||
|
||||
print(ctx.stats) # {"hits": 1, "misses": 1}
|
||||
```
|
||||
|
||||
### CachedLLM wrapper
|
||||
|
||||
```python
|
||||
from semantic_llm_cache import CachedLLM
|
||||
|
||||
llm = CachedLLM(similarity=0.9, ttl=3600)
|
||||
response = await llm.achat("What is Python?", llm_func=my_async_llm)
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### `@cache()` Decorator
|
||||
|
||||
```python
|
||||
@cache(
|
||||
similarity: float = 1.0, # 1.0 = exact match, 0.9 = semantic
|
||||
ttl: int = 3600, # seconds, None = forever
|
||||
backend: Backend = None, # None = in-memory
|
||||
namespace: str = "default", # isolate different use cases
|
||||
enabled: bool = True, # toggle for debugging
|
||||
key_func: Callable = None, # custom cache key
|
||||
)
|
||||
async def my_llm_function(prompt: str) -> str:
|
||||
...
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| ------------ | ------------- | ----------- | --------------------------------------------------------- |
|
||||
| `similarity` | `float` | `1.0` | Cosine similarity threshold (1.0 = exact, 0.9 = semantic) |
|
||||
| `ttl` | `int \| None` | `3600` | Time-to-live in seconds (None = never expires) |
|
||||
| `backend` | `Backend` | `None` | Storage backend (None = in-memory) |
|
||||
| `namespace` | `str` | `"default"` | Isolate different use cases |
|
||||
| `enabled` | `bool` | `True` | Enable/disable caching |
|
||||
| `key_func` | `Callable` | `None` | Custom cache key function |
|
||||
|
||||
### Utility Functions
|
||||
|
||||
```python
|
||||
from semantic_llm_cache import get_stats # sync — safe anywhere
|
||||
from semantic_llm_cache.stats import (
|
||||
clear_cache, # async
|
||||
invalidate, # async
|
||||
warm_cache, # async
|
||||
export_cache, # async
|
||||
)
|
||||
```
|
||||
|
||||
## Backends
|
||||
|
||||
| Backend | Description | I/O |
|
||||
| --------------- | ------------------------------------ | ------------------------- |
|
||||
| `MemoryBackend` | In-memory LRU (default) | none — runs in event loop |
|
||||
| `SQLiteBackend` | Persistent, file-based (`aiosqlite`) | async non-blocking |
|
||||
| `RedisBackend` | Distributed (`redis.asyncio`) | async non-blocking |
|
||||
|
||||
## Embedding Providers
|
||||
|
||||
| Provider | Quality | Notes |
|
||||
| ----------------------------- | ---------------------------- | --------------------------- |
|
||||
| `DummyEmbeddingProvider` | hash-only, no semantic match | zero deps, default |
|
||||
| `SentenceTransformerProvider` | high (local model) | requires `[semantic]` extra |
|
||||
| `OpenAIEmbeddingProvider` | high (API) | requires `[openai]` extra |
|
||||
|
||||
Embedding inference is offloaded via `asyncio.to_thread` — model loading is blocking and should be done at application startup, not on first request.
|
||||
|
||||
```python
|
||||
from semantic_llm_cache.similarity import create_embedding_provider, EmbeddingCache
|
||||
|
||||
# Pre-load at startup (blocking — do this in lifespan, not a request handler)
|
||||
provider = create_embedding_provider("sentence-transformer")
|
||||
embedding_cache = EmbeddingCache(provider=provider)
|
||||
|
||||
# Use in request handlers (non-blocking)
|
||||
embedding = await embedding_cache.aencode("my prompt")
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
| Metric | Value |
|
||||
| -------------------------- | ---------------------------------------- |
|
||||
| Cache hit latency | <10ms |
|
||||
| Embedding overhead on miss | ~50ms (sentence-transformers, offloaded) |
|
||||
| Typical hit rate | 25-40% |
|
||||
| Cost reduction | 20-40% |
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python >= 3.10
|
||||
- numpy >= 1.24.0
|
||||
- aiosqlite >= 0.19.0
|
||||
|
||||
### Optional
|
||||
|
||||
- `sentence-transformers >= 2.2.0` — semantic matching
|
||||
- `redis >= 4.2.0` — Redis backend (includes `redis.asyncio`)
|
||||
- `openai >= 1.0.0` — OpenAI embeddings
|
||||
|
||||
## License
|
||||
|
||||
MIT — see [LICENSE](LICENSE).
|
||||
|
||||
## Credits
|
||||
|
||||
Original library by **Karthick Raja M** ([@karthyick](https://github.com/karthyick)).
|
||||
Async conversion by this fork.
|
||||
112
pyproject.toml
Normal file
112
pyproject.toml
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "semantic-llm-cache"
|
||||
version = "0.2.0"
|
||||
description = "Async semantic caching for LLM API calls - reduce costs with one decorator"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = {text = "MIT"}
|
||||
authors = [
|
||||
{name = "Karthick Raja M", email = "karthickrajam18@gmail.com"}
|
||||
]
|
||||
keywords = [
|
||||
"llm",
|
||||
"cache",
|
||||
"semantic",
|
||||
"async",
|
||||
"openai",
|
||||
"anthropic",
|
||||
"ollama",
|
||||
"prompt",
|
||||
"optimization",
|
||||
"cost-reduction",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Framework :: AsyncIO",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"numpy>=1.24.0",
|
||||
"aiosqlite>=0.19.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
semantic = [
|
||||
"sentence-transformers>=2.2.0",
|
||||
]
|
||||
redis = [
|
||||
"redis>=4.2.0",
|
||||
]
|
||||
openai = [
|
||||
"openai>=1.0.0",
|
||||
]
|
||||
all = [
|
||||
"sentence-transformers>=2.2.0",
|
||||
"redis>=4.2.0",
|
||||
"openai>=1.0.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=7.0.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"pytest-cov>=4.0.0",
|
||||
"black>=23.0.0",
|
||||
"ruff>=0.1.0",
|
||||
"mypy>=1.0.0",
|
||||
"build>=1.0.0",
|
||||
"twine>=4.0.0",
|
||||
"bump2version>=1.0.1",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/karthyick/prompt-cache"
|
||||
Documentation = "https://github.com/karthyick/prompt-cache#readme"
|
||||
Repository = "https://github.com/karthyick/prompt-cache.git"
|
||||
"Bug Tracker" = "https://github.com/karthyick/prompt-cache/issues"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["semantic_llm_cache*"]
|
||||
exclude = ["tests*", "examples*", "docs*"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
semantic_llm_cache = ["py.typed"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
target-version = ['py310', 'py311', 'py312', 'py313']
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "W", "UP", "B", "C4"]
|
||||
ignore = ["E501"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
disable_error_code = ["annotation-unchecked"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
asyncio_mode = "auto"
|
||||
addopts = "-v --cov=semantic_llm_cache --cov-report=html --cov-report=term-missing --cov-fail-under=90"
|
||||
54
semantic_llm_cache/__init__.py
Normal file
54
semantic_llm_cache/__init__.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
llm-semantic-cache: Semantic caching for LLM API calls.
|
||||
|
||||
Cut LLM costs 30% with one decorator.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__author__ = "Karthick Raja M"
|
||||
__license__ = "MIT"
|
||||
|
||||
# Core exports
|
||||
from semantic_llm_cache.config import CacheConfig
|
||||
from semantic_llm_cache.core import (
|
||||
CacheContext,
|
||||
CachedLLM,
|
||||
cache,
|
||||
get_default_backend,
|
||||
set_default_backend,
|
||||
)
|
||||
from semantic_llm_cache.exceptions import (
|
||||
CacheBackendError,
|
||||
CacheNotFoundError,
|
||||
CacheSerializationError,
|
||||
PromptCacheError,
|
||||
)
|
||||
from semantic_llm_cache.stats import CacheStats, clear_cache, get_stats, invalidate
|
||||
from semantic_llm_cache.storage import StorageBackend
|
||||
|
||||
__all__ = [
|
||||
# Version info
|
||||
"__version__",
|
||||
"__author__",
|
||||
"__license__",
|
||||
# Core API
|
||||
"cache",
|
||||
"CacheContext",
|
||||
"CachedLLM",
|
||||
"get_default_backend",
|
||||
"set_default_backend",
|
||||
# Storage
|
||||
"StorageBackend",
|
||||
# Statistics
|
||||
"CacheStats",
|
||||
"get_stats",
|
||||
"clear_cache",
|
||||
"invalidate",
|
||||
# Configuration
|
||||
"CacheConfig",
|
||||
# Exceptions
|
||||
"PromptCacheError",
|
||||
"CacheBackendError",
|
||||
"CacheSerializationError",
|
||||
"CacheNotFoundError",
|
||||
]
|
||||
21
semantic_llm_cache/backends/__init__.py
Normal file
21
semantic_llm_cache/backends/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""Storage backends for llm-semantic-cache."""
|
||||
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.backends.memory import MemoryBackend
|
||||
|
||||
try:
|
||||
from semantic_llm_cache.backends.sqlite import SQLiteBackend
|
||||
except ImportError:
|
||||
SQLiteBackend = None # type: ignore
|
||||
|
||||
try:
|
||||
from semantic_llm_cache.backends.redis import RedisBackend
|
||||
except ImportError:
|
||||
RedisBackend = None # type: ignore
|
||||
|
||||
__all__ = [
|
||||
"BaseBackend",
|
||||
"MemoryBackend",
|
||||
"SQLiteBackend",
|
||||
"RedisBackend",
|
||||
]
|
||||
104
semantic_llm_cache/backends/base.py
Normal file
104
semantic_llm_cache/backends/base.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Base backend implementation with common functionality."""
|
||||
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.storage import StorageBackend
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two vectors.
|
||||
|
||||
Args:
|
||||
a: First vector
|
||||
b: Second vector
|
||||
|
||||
Returns:
|
||||
Similarity score between 0 and 1
|
||||
"""
|
||||
a_arr = np.asarray(a, dtype=np.float32)
|
||||
b_arr = np.asarray(b, dtype=np.float32)
|
||||
|
||||
dot_product = np.dot(a_arr, b_arr)
|
||||
norm_a = np.linalg.norm(a_arr)
|
||||
norm_b = np.linalg.norm(b_arr)
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
|
||||
return float(dot_product / (norm_a * norm_b))
|
||||
|
||||
|
||||
class BaseBackend(StorageBackend):
|
||||
"""Base backend with common sync helpers; async public interface via StorageBackend."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize base backend."""
|
||||
self._hits: int = 0
|
||||
self._misses: int = 0
|
||||
|
||||
def _increment_hits(self) -> None:
|
||||
"""Increment hit counter."""
|
||||
self._hits += 1
|
||||
|
||||
def _increment_misses(self) -> None:
|
||||
"""Increment miss counter."""
|
||||
self._misses += 1
|
||||
|
||||
def _check_expired(self, entry: CacheEntry) -> bool:
|
||||
"""Check if entry is expired.
|
||||
|
||||
Args:
|
||||
entry: CacheEntry to check
|
||||
|
||||
Returns:
|
||||
True if expired, False otherwise
|
||||
"""
|
||||
return entry.is_expired(time.time())
|
||||
|
||||
def _find_best_match(
|
||||
self,
|
||||
candidates: list[tuple[str, CacheEntry]],
|
||||
query_embedding: list[float],
|
||||
threshold: float,
|
||||
) -> Optional[tuple[str, CacheEntry, float]]:
|
||||
"""Find best matching entry from candidates.
|
||||
|
||||
Sync helper — CPU-only numpy ops, safe to call from async context.
|
||||
|
||||
Args:
|
||||
candidates: List of (key, entry) tuples
|
||||
query_embedding: Query embedding vector
|
||||
threshold: Minimum similarity threshold
|
||||
|
||||
Returns:
|
||||
(key, entry, similarity) tuple if found above threshold, None otherwise
|
||||
"""
|
||||
best_match: Optional[tuple[str, CacheEntry, float]] = None
|
||||
best_similarity = threshold
|
||||
|
||||
for key, entry in candidates:
|
||||
if entry.embedding is None:
|
||||
continue
|
||||
|
||||
similarity = cosine_similarity(query_embedding, entry.embedding)
|
||||
if similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = (key, entry, similarity)
|
||||
|
||||
return best_match
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""Get backend statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with hits and misses
|
||||
"""
|
||||
return {
|
||||
"hits": self._hits,
|
||||
"misses": self._misses,
|
||||
"hit_rate": self._hits / max(self._hits + self._misses, 1),
|
||||
}
|
||||
179
semantic_llm_cache/backends/memory.py
Normal file
179
semantic_llm_cache/backends/memory.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
"""In-memory storage backend."""
|
||||
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
|
||||
class MemoryBackend(BaseBackend):
|
||||
"""In-memory cache storage with LRU eviction.
|
||||
|
||||
All operations are in-memory dict access — no I/O — so async methods
|
||||
run directly in the event loop without thread offloading.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: Optional[int] = None) -> None:
|
||||
"""Initialize memory backend.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries to store (LRU eviction when reached)
|
||||
"""
|
||||
super().__init__()
|
||||
self._cache: dict[str, CacheEntry] = {}
|
||||
self._access_order: dict[str, float] = {}
|
||||
self._max_size = max_size
|
||||
self._access_counter: float = 0.0
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict oldest entry if at capacity."""
|
||||
if self._max_size is None or len(self._cache) < self._max_size:
|
||||
return
|
||||
|
||||
if self._access_order:
|
||||
lru_key = min(self._access_order, key=lambda k: self._access_order.get(k, 0))
|
||||
del self._cache[lru_key]
|
||||
del self._access_order[lru_key]
|
||||
|
||||
def _update_access_time(self, key: str) -> None:
|
||||
"""Update access time for LRU tracking."""
|
||||
self._access_counter += 1
|
||||
self._access_order[key] = self._access_counter
|
||||
|
||||
async def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Retrieve cache entry by key.
|
||||
|
||||
Args:
|
||||
key: Cache key to retrieve
|
||||
|
||||
Returns:
|
||||
CacheEntry if found and not expired, None otherwise
|
||||
"""
|
||||
try:
|
||||
entry = self._cache.get(key)
|
||||
if entry is None:
|
||||
self._increment_misses()
|
||||
return None
|
||||
|
||||
if self._check_expired(entry):
|
||||
await self.delete(key)
|
||||
self._increment_misses()
|
||||
return None
|
||||
|
||||
self._increment_hits()
|
||||
self._update_access_time(key)
|
||||
entry.hit_count += 1
|
||||
return entry
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to get entry: {e}") from e
|
||||
|
||||
async def set(self, key: str, entry: CacheEntry) -> None:
|
||||
"""Store cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to store under
|
||||
entry: CacheEntry to store
|
||||
"""
|
||||
try:
|
||||
self._evict_if_needed()
|
||||
self._cache[key] = entry
|
||||
self._update_access_time(key)
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to set entry: {e}") from e
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete
|
||||
|
||||
Returns:
|
||||
True if entry was deleted, False if not found
|
||||
"""
|
||||
try:
|
||||
if key in self._cache:
|
||||
del self._cache[key]
|
||||
self._access_order.pop(key, None)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to delete entry: {e}") from e
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all cache entries."""
|
||||
try:
|
||||
self._cache.clear()
|
||||
self._access_order.clear()
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to clear cache: {e}") from e
|
||||
|
||||
async def iterate(
|
||||
self, namespace: Optional[str] = None
|
||||
) -> list[tuple[str, CacheEntry]]:
|
||||
"""Iterate over cache entries, optionally filtered by namespace.
|
||||
|
||||
Args:
|
||||
namespace: Optional namespace filter
|
||||
|
||||
Returns:
|
||||
List of (key, entry) tuples
|
||||
"""
|
||||
try:
|
||||
if namespace is None:
|
||||
return list(self._cache.items())
|
||||
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in self._cache.items()
|
||||
if v.namespace == namespace and not self._check_expired(v)
|
||||
]
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to iterate entries: {e}") from e
|
||||
|
||||
async def find_similar(
|
||||
self,
|
||||
embedding: list[float],
|
||||
threshold: float,
|
||||
namespace: Optional[str] = None,
|
||||
) -> Optional[tuple[str, CacheEntry, float]]:
|
||||
"""Find semantically similar cached entry.
|
||||
|
||||
Args:
|
||||
embedding: Query embedding vector
|
||||
threshold: Minimum similarity score (0-1)
|
||||
namespace: Optional namespace filter
|
||||
|
||||
Returns:
|
||||
(key, entry, similarity) tuple if found above threshold, None otherwise
|
||||
"""
|
||||
try:
|
||||
candidates = [
|
||||
(k, v)
|
||||
for k, v in self._cache.items()
|
||||
if v.embedding is not None
|
||||
and not self._check_expired(v)
|
||||
and (namespace is None or v.namespace == namespace)
|
||||
]
|
||||
return self._find_best_match(candidates, embedding, threshold)
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to find similar entry: {e}") from e
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""Get backend statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with size, memory usage, hits, misses
|
||||
"""
|
||||
base_stats = await super().get_stats()
|
||||
memory_usage = sys.getsizeof(self._cache) + sum(
|
||||
sys.getsizeof(k) + sys.getsizeof(v) for k, v in self._cache.items()
|
||||
)
|
||||
|
||||
return {
|
||||
**base_stats,
|
||||
"size": len(self._cache),
|
||||
"memory_bytes": memory_usage,
|
||||
"max_size": self._max_size,
|
||||
}
|
||||
239
semantic_llm_cache/backends/redis.py
Normal file
239
semantic_llm_cache/backends/redis.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""Redis distributed storage backend (async via redis.asyncio)."""
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
from redis import asyncio as aioredis
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Redis backend requires 'redis' package. "
|
||||
"Install with: pip install semantic-llm-cache[redis]"
|
||||
) from err
|
||||
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
|
||||
class RedisBackend(BaseBackend):
|
||||
"""Redis-based distributed cache storage (async).
|
||||
|
||||
Uses redis.asyncio (bundled with redis>=4.2) for non-blocking I/O.
|
||||
The connection is created in __init__; no explicit connect() call needed
|
||||
as redis.asyncio uses a connection pool that connects lazily.
|
||||
"""
|
||||
|
||||
DEFAULT_PREFIX = "semantic_llm_cache:"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str = "redis://localhost:6379/0",
|
||||
prefix: str = DEFAULT_PREFIX,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize Redis backend.
|
||||
|
||||
Args:
|
||||
url: Redis connection URL
|
||||
prefix: Key prefix for cache entries
|
||||
**kwargs: Additional arguments passed to redis.asyncio.from_url
|
||||
"""
|
||||
super().__init__()
|
||||
self._prefix = prefix.rstrip(":") + ":"
|
||||
self._redis = aioredis.from_url(url, **kwargs)
|
||||
|
||||
async def ping(self) -> None:
|
||||
"""Test Redis connection. Call this after construction to verify connectivity.
|
||||
|
||||
Raises:
|
||||
CacheBackendError: If Redis is not reachable
|
||||
"""
|
||||
try:
|
||||
await self._redis.ping()
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to connect to Redis: {e}") from e
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
"""Create full Redis key with prefix."""
|
||||
return f"{self._prefix}{key}"
|
||||
|
||||
def _entry_to_dict(self, entry: CacheEntry) -> dict[str, Any]:
|
||||
"""Convert CacheEntry to dictionary for storage."""
|
||||
return {
|
||||
"prompt": entry.prompt,
|
||||
"response": entry.response,
|
||||
"embedding": entry.embedding,
|
||||
"created_at": entry.created_at,
|
||||
"ttl": entry.ttl,
|
||||
"namespace": entry.namespace,
|
||||
"hit_count": entry.hit_count,
|
||||
"input_tokens": entry.input_tokens,
|
||||
"output_tokens": entry.output_tokens,
|
||||
}
|
||||
|
||||
def _dict_to_entry(self, data: dict[str, Any]) -> CacheEntry:
|
||||
"""Convert dictionary from storage to CacheEntry."""
|
||||
return CacheEntry(
|
||||
prompt=data["prompt"],
|
||||
response=data["response"],
|
||||
embedding=data.get("embedding"),
|
||||
created_at=data["created_at"],
|
||||
ttl=data.get("ttl"),
|
||||
namespace=data.get("namespace", "default"),
|
||||
hit_count=data.get("hit_count", 0),
|
||||
input_tokens=data.get("input_tokens", 0),
|
||||
output_tokens=data.get("output_tokens", 0),
|
||||
)
|
||||
|
||||
async def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Retrieve cache entry by key.
|
||||
|
||||
Args:
|
||||
key: Cache key to retrieve
|
||||
|
||||
Returns:
|
||||
CacheEntry if found and not expired, None otherwise
|
||||
"""
|
||||
try:
|
||||
redis_key = self._make_key(key)
|
||||
data = await self._redis.get(redis_key)
|
||||
|
||||
if data is None:
|
||||
self._increment_misses()
|
||||
return None
|
||||
|
||||
entry_dict = json.loads(data)
|
||||
entry = self._dict_to_entry(entry_dict)
|
||||
|
||||
if self._check_expired(entry):
|
||||
await self.delete(key)
|
||||
self._increment_misses()
|
||||
return None
|
||||
|
||||
self._increment_hits()
|
||||
entry.hit_count += 1
|
||||
|
||||
entry_dict["hit_count"] = entry.hit_count
|
||||
await self._redis.set(redis_key, json.dumps(entry_dict))
|
||||
|
||||
return entry
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to get entry: {e}") from e
|
||||
|
||||
async def set(self, key: str, entry: CacheEntry) -> None:
|
||||
"""Store cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to store under
|
||||
entry: CacheEntry to store
|
||||
"""
|
||||
try:
|
||||
redis_key = self._make_key(key)
|
||||
data = json.dumps(self._entry_to_dict(entry))
|
||||
redis_ttl = entry.ttl if entry.ttl is not None else 0
|
||||
await self._redis.set(redis_key, data, ex=redis_ttl if redis_ttl > 0 else None)
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to set entry: {e}") from e
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete
|
||||
|
||||
Returns:
|
||||
True if entry was deleted, False if not found
|
||||
"""
|
||||
try:
|
||||
result = await self._redis.delete(self._make_key(key))
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to delete entry: {e}") from e
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all cache entries with this prefix."""
|
||||
try:
|
||||
keys = await self._redis.keys(f"{self._prefix}*")
|
||||
if keys:
|
||||
await self._redis.delete(*keys)
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to clear cache: {e}") from e
|
||||
|
||||
async def iterate(
|
||||
self, namespace: Optional[str] = None
|
||||
) -> list[tuple[str, CacheEntry]]:
|
||||
"""Iterate over cache entries, optionally filtered by namespace.
|
||||
|
||||
Args:
|
||||
namespace: Optional namespace filter
|
||||
|
||||
Returns:
|
||||
List of (key, entry) tuples
|
||||
"""
|
||||
try:
|
||||
keys = await self._redis.keys(f"{self._prefix}*")
|
||||
results = []
|
||||
|
||||
for full_key in keys:
|
||||
short_key = full_key.decode().replace(self._prefix, "", 1)
|
||||
data = await self._redis.get(full_key)
|
||||
|
||||
if data:
|
||||
entry_dict = json.loads(data)
|
||||
entry = self._dict_to_entry(entry_dict)
|
||||
|
||||
if namespace is None or entry.namespace == namespace:
|
||||
if not self._check_expired(entry):
|
||||
results.append((short_key, entry))
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to iterate entries: {e}") from e
|
||||
|
||||
async def find_similar(
|
||||
self,
|
||||
embedding: list[float],
|
||||
threshold: float,
|
||||
namespace: Optional[str] = None,
|
||||
) -> Optional[tuple[str, CacheEntry, float]]:
|
||||
"""Find semantically similar cached entry.
|
||||
|
||||
Note: Loads all entries for cosine scan. For large datasets consider
|
||||
Redis Stack with vector search (RediSearch).
|
||||
|
||||
Args:
|
||||
embedding: Query embedding vector
|
||||
threshold: Minimum similarity score (0-1)
|
||||
namespace: Optional namespace filter
|
||||
|
||||
Returns:
|
||||
(key, entry, similarity) tuple if found above threshold, None otherwise
|
||||
"""
|
||||
try:
|
||||
entries = await self.iterate(namespace)
|
||||
candidates = [(k, v) for k, v in entries if v.embedding is not None]
|
||||
return self._find_best_match(candidates, embedding, threshold)
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to find similar entry: {e}") from e
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""Get backend statistics."""
|
||||
base_stats = await super().get_stats()
|
||||
|
||||
try:
|
||||
keys = await self._redis.keys(f"{self._prefix}*")
|
||||
return {
|
||||
**base_stats,
|
||||
"size": len(keys) if keys else 0,
|
||||
"prefix": self._prefix,
|
||||
}
|
||||
except Exception as e:
|
||||
return {**base_stats, "size": 0, "prefix": self._prefix, "error": str(e)}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Redis connection."""
|
||||
try:
|
||||
await self._redis.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
279
semantic_llm_cache/backends/sqlite.py
Normal file
279
semantic_llm_cache/backends/sqlite.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"""SQLite persistent storage backend (async via aiosqlite)."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
import aiosqlite
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"SQLite backend requires 'aiosqlite' package. "
|
||||
"Install with: pip install semantic-llm-cache[sqlite]"
|
||||
) from err
|
||||
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
|
||||
class SQLiteBackend(BaseBackend):
|
||||
"""SQLite-based persistent cache storage (async).
|
||||
|
||||
Uses aiosqlite for non-blocking I/O. A single persistent connection
|
||||
is opened lazily on first use and reused for all subsequent operations.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str | Path = "semantic_cache.db") -> None:
|
||||
"""Initialize SQLite backend.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database file, or ":memory:" for in-memory DB
|
||||
"""
|
||||
super().__init__()
|
||||
self._db_path = str(db_path) if isinstance(db_path, Path) else db_path
|
||||
self._conn: Optional[aiosqlite.Connection] = None
|
||||
|
||||
async def _get_conn(self) -> aiosqlite.Connection:
|
||||
"""Get or create the persistent async connection."""
|
||||
if self._conn is None:
|
||||
self._conn = await aiosqlite.connect(self._db_path)
|
||||
self._conn.row_factory = aiosqlite.Row
|
||||
await self._initialize_schema()
|
||||
return self._conn
|
||||
|
||||
async def _initialize_schema(self) -> None:
|
||||
"""Initialize database schema."""
|
||||
conn = await self._get_conn()
|
||||
await conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS cache_entries (
|
||||
key TEXT PRIMARY KEY,
|
||||
prompt TEXT NOT NULL,
|
||||
response TEXT NOT NULL,
|
||||
embedding TEXT,
|
||||
created_at REAL NOT NULL,
|
||||
ttl INTEGER,
|
||||
namespace TEXT NOT NULL DEFAULT 'default',
|
||||
hit_count INTEGER DEFAULT 0,
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_namespace
|
||||
ON cache_entries(namespace)
|
||||
"""
|
||||
)
|
||||
await conn.commit()
|
||||
|
||||
def _row_to_entry(self, row: aiosqlite.Row) -> CacheEntry:
|
||||
"""Convert database row to CacheEntry."""
|
||||
embedding = None
|
||||
if row["embedding"]:
|
||||
embedding = json.loads(row["embedding"])
|
||||
|
||||
return CacheEntry(
|
||||
prompt=row["prompt"],
|
||||
response=json.loads(row["response"]),
|
||||
embedding=embedding,
|
||||
created_at=row["created_at"],
|
||||
ttl=row["ttl"],
|
||||
namespace=row["namespace"],
|
||||
hit_count=row["hit_count"],
|
||||
input_tokens=row["input_tokens"],
|
||||
output_tokens=row["output_tokens"],
|
||||
)
|
||||
|
||||
async def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Retrieve cache entry by key.
|
||||
|
||||
Args:
|
||||
key: Cache key to retrieve
|
||||
|
||||
Returns:
|
||||
CacheEntry if found and not expired, None otherwise
|
||||
"""
|
||||
try:
|
||||
conn = await self._get_conn()
|
||||
async with conn.execute(
|
||||
"SELECT * FROM cache_entries WHERE key = ?", (key,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
self._increment_misses()
|
||||
return None
|
||||
|
||||
entry = self._row_to_entry(row)
|
||||
|
||||
if self._check_expired(entry):
|
||||
await self.delete(key)
|
||||
self._increment_misses()
|
||||
return None
|
||||
|
||||
self._increment_hits()
|
||||
entry.hit_count += 1
|
||||
|
||||
await conn.execute(
|
||||
"UPDATE cache_entries SET hit_count = hit_count + 1 WHERE key = ?",
|
||||
(key,),
|
||||
)
|
||||
await conn.commit()
|
||||
|
||||
return entry
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to get entry: {e}") from e
|
||||
|
||||
async def set(self, key: str, entry: CacheEntry) -> None:
|
||||
"""Store cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to store under
|
||||
entry: CacheEntry to store
|
||||
"""
|
||||
try:
|
||||
conn = await self._get_conn()
|
||||
embedding_json = json.dumps(entry.embedding) if entry.embedding else None
|
||||
response_json = json.dumps(entry.response)
|
||||
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO cache_entries
|
||||
(key, prompt, response, embedding, created_at, ttl, namespace,
|
||||
hit_count, input_tokens, output_tokens)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
key,
|
||||
entry.prompt,
|
||||
response_json,
|
||||
embedding_json,
|
||||
entry.created_at,
|
||||
entry.ttl,
|
||||
entry.namespace,
|
||||
entry.hit_count,
|
||||
entry.input_tokens,
|
||||
entry.output_tokens,
|
||||
),
|
||||
)
|
||||
await conn.commit()
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to set entry: {e}") from e
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete
|
||||
|
||||
Returns:
|
||||
True if entry was deleted, False if not found
|
||||
"""
|
||||
try:
|
||||
conn = await self._get_conn()
|
||||
async with conn.execute(
|
||||
"DELETE FROM cache_entries WHERE key = ?", (key,)
|
||||
) as cursor:
|
||||
rowcount = cursor.rowcount
|
||||
await conn.commit()
|
||||
return rowcount > 0
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to delete entry: {e}") from e
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all cache entries."""
|
||||
try:
|
||||
conn = await self._get_conn()
|
||||
await conn.execute("DELETE FROM cache_entries")
|
||||
await conn.commit()
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to clear cache: {e}") from e
|
||||
|
||||
async def iterate(
|
||||
self, namespace: Optional[str] = None
|
||||
) -> list[tuple[str, CacheEntry]]:
|
||||
"""Iterate over cache entries, optionally filtered by namespace.
|
||||
|
||||
Args:
|
||||
namespace: Optional namespace filter
|
||||
|
||||
Returns:
|
||||
List of (key, entry) tuples
|
||||
"""
|
||||
try:
|
||||
conn = await self._get_conn()
|
||||
|
||||
if namespace is None:
|
||||
query = "SELECT key, * FROM cache_entries"
|
||||
params: tuple[()] = ()
|
||||
else:
|
||||
query = "SELECT key, * FROM cache_entries WHERE namespace = ?"
|
||||
params = (namespace,)
|
||||
|
||||
async with conn.execute(query, params) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
key = row["key"]
|
||||
entry = self._row_to_entry(row)
|
||||
if not self._check_expired(entry):
|
||||
results.append((key, entry))
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to iterate entries: {e}") from e
|
||||
|
||||
async def find_similar(
|
||||
self,
|
||||
embedding: list[float],
|
||||
threshold: float,
|
||||
namespace: Optional[str] = None,
|
||||
) -> Optional[tuple[str, CacheEntry, float]]:
|
||||
"""Find semantically similar cached entry.
|
||||
|
||||
Args:
|
||||
embedding: Query embedding vector
|
||||
threshold: Minimum similarity score (0-1)
|
||||
namespace: Optional namespace filter
|
||||
|
||||
Returns:
|
||||
(key, entry, similarity) tuple if found above threshold, None otherwise
|
||||
"""
|
||||
try:
|
||||
entries = await self.iterate(namespace)
|
||||
candidates = [(k, v) for k, v in entries if v.embedding is not None]
|
||||
return self._find_best_match(candidates, embedding, threshold)
|
||||
except Exception as e:
|
||||
raise CacheBackendError(f"Failed to find similar entry: {e}") from e
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""Get backend statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with size, database path, hits, misses
|
||||
"""
|
||||
base_stats = await super().get_stats()
|
||||
|
||||
try:
|
||||
conn = await self._get_conn()
|
||||
async with conn.execute("SELECT COUNT(*) FROM cache_entries") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
size = row[0] if row else 0
|
||||
|
||||
return {
|
||||
**base_stats,
|
||||
"size": size,
|
||||
"db_path": self._db_path,
|
||||
}
|
||||
except Exception as e:
|
||||
return {**base_stats, "size": 0, "db_path": self._db_path, "error": str(e)}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close database connection."""
|
||||
if self._conn is not None:
|
||||
await self._conn.close()
|
||||
self._conn = None
|
||||
61
semantic_llm_cache/config.py
Normal file
61
semantic_llm_cache/config.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""Configuration management for prompt-cache."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheConfig:
|
||||
"""Configuration for cache behavior."""
|
||||
|
||||
similarity_threshold: float = 1.0 # 1.0 = exact match, lower = semantic
|
||||
ttl: Optional[int] = 3600 # Time to live in seconds, None = forever
|
||||
namespace: str = "default" # Isolate different use cases
|
||||
enabled: bool = True # Enable/disable caching
|
||||
key_func: Optional[Callable[[Any], str]] = None # Custom cache key function
|
||||
|
||||
# Cost estimation for statistics (USD per 1K tokens)
|
||||
input_cost_per_1k: float = 0.001 # Default ~$1/1M for cheaper models
|
||||
output_cost_per_1k: float = 0.002 # Default ~$2/1M for cheaper models
|
||||
|
||||
# Performance settings
|
||||
max_cache_size: Optional[int] = None # LRU eviction when set
|
||||
embedding_model: str = "all-MiniLM-L6-v2" # Default sentence-transformer model
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration."""
|
||||
if not 0.0 <= self.similarity_threshold <= 1.0:
|
||||
raise ValueError("similarity_threshold must be between 0.0 and 1.0")
|
||||
if self.ttl is not None and self.ttl <= 0:
|
||||
raise ValueError("ttl must be positive or None")
|
||||
if self.max_cache_size is not None and self.max_cache_size <= 0:
|
||||
raise ValueError("max_cache_size must be positive or None")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""A cached response with metadata."""
|
||||
|
||||
prompt: str
|
||||
response: Any
|
||||
embedding: Optional[list[float]] = None # Normalized embedding vector
|
||||
created_at: float = 0.0 # Unix timestamp
|
||||
ttl: Optional[int] = None # Time to live in seconds
|
||||
namespace: str = "default"
|
||||
hit_count: int = 0
|
||||
|
||||
# Approximate token counts for cost estimation
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
|
||||
def is_expired(self, current_time: float) -> bool:
|
||||
"""Check if entry has expired based on TTL."""
|
||||
if self.ttl is None:
|
||||
return False
|
||||
return (current_time - self.created_at) > self.ttl
|
||||
|
||||
def estimate_cost(self, input_cost: float, output_cost: float) -> float:
|
||||
"""Estimate cost savings in USD."""
|
||||
input_savings = (self.input_tokens / 1000) * input_cost
|
||||
output_savings = (self.output_tokens / 1000) * output_cost
|
||||
return input_savings + output_savings
|
||||
369
semantic_llm_cache/core.py
Normal file
369
semantic_llm_cache/core.py
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
"""Core cache decorator and API for llm-semantic-cache."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import time
|
||||
from typing import Any, Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
from semantic_llm_cache.backends import MemoryBackend
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.config import CacheConfig, CacheEntry
|
||||
from semantic_llm_cache.exceptions import PromptCacheError
|
||||
from semantic_llm_cache.similarity import EmbeddingCache
|
||||
from semantic_llm_cache.stats import _stats_manager
|
||||
from semantic_llm_cache.utils import hash_prompt, normalize_prompt
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _extract_prompt(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
"""Extract prompt string from function arguments."""
|
||||
if args and isinstance(args[0], str):
|
||||
return args[0]
|
||||
if "prompt" in kwargs:
|
||||
return str(kwargs["prompt"])
|
||||
return str(args) + str(sorted(kwargs.items()))
|
||||
|
||||
|
||||
class CacheContext:
|
||||
"""Context manager for cache configuration.
|
||||
|
||||
Supports both sync (with) and async (async with) usage.
|
||||
|
||||
Examples:
|
||||
>>> async with CacheContext(similarity=0.9) as ctx:
|
||||
... result = await llm_call("prompt")
|
||||
... print(ctx.stats)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
similarity: Optional[float] = None,
|
||||
ttl: Optional[int] = None,
|
||||
namespace: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
) -> None:
|
||||
self._config = CacheConfig(
|
||||
similarity_threshold=similarity if similarity is not None else 1.0,
|
||||
ttl=ttl,
|
||||
namespace=namespace if namespace is not None else "default",
|
||||
enabled=enabled if enabled is not None else True,
|
||||
)
|
||||
self._stats: dict[str, Any] = {"hits": 0, "misses": 0}
|
||||
|
||||
def __enter__(self) -> "CacheContext":
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
pass
|
||||
|
||||
async def __aenter__(self) -> "CacheContext":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def stats(self) -> dict[str, Any]:
|
||||
return self._stats.copy()
|
||||
|
||||
@property
|
||||
def config(self) -> CacheConfig:
|
||||
return self._config
|
||||
|
||||
|
||||
class CachedLLM:
|
||||
"""Wrapper class for LLM calls with automatic caching.
|
||||
|
||||
Examples:
|
||||
>>> llm = CachedLLM(similarity=0.9)
|
||||
>>> response = await llm.achat("What is Python?", llm_func=my_async_llm)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str = "openai",
|
||||
model: str = "gpt-4",
|
||||
similarity: float = 1.0,
|
||||
ttl: Optional[int] = 3600,
|
||||
backend: Optional[BaseBackend] = None,
|
||||
namespace: str = "default",
|
||||
enabled: bool = True,
|
||||
) -> None:
|
||||
self._provider = provider
|
||||
self._model = model
|
||||
self._backend = backend or MemoryBackend()
|
||||
self._embedding_cache = EmbeddingCache()
|
||||
self._config = CacheConfig(
|
||||
similarity_threshold=similarity,
|
||||
ttl=ttl,
|
||||
namespace=namespace,
|
||||
enabled=enabled,
|
||||
)
|
||||
|
||||
async def achat(
|
||||
self,
|
||||
prompt: str,
|
||||
llm_func: Optional[Callable[[str], Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Get response with caching (async).
|
||||
|
||||
Args:
|
||||
prompt: Input prompt
|
||||
llm_func: Async or sync LLM function to call on cache miss
|
||||
**kwargs: Additional arguments for llm_func
|
||||
|
||||
Returns:
|
||||
LLM response (cached or fresh)
|
||||
"""
|
||||
if llm_func is None:
|
||||
raise ValueError("llm_func is required for CachedLLM.achat()")
|
||||
|
||||
@cache(
|
||||
similarity=self._config.similarity_threshold,
|
||||
ttl=self._config.ttl,
|
||||
backend=self._backend,
|
||||
namespace=self._config.namespace,
|
||||
enabled=self._config.enabled,
|
||||
)
|
||||
async def _cached_call(p: str) -> Any:
|
||||
result = llm_func(p, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
return await _cached_call(prompt)
|
||||
|
||||
|
||||
def cache(
|
||||
similarity: float = 1.0,
|
||||
ttl: Optional[int] = 3600,
|
||||
backend: Optional[BaseBackend] = None,
|
||||
namespace: str = "default",
|
||||
enabled: bool = True,
|
||||
key_func: Optional[Callable[..., str]] = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Decorator for caching LLM function responses.
|
||||
|
||||
Auto-detects whether the decorated function is async or sync and returns
|
||||
the appropriate wrapper. Both variants share identical cache logic.
|
||||
|
||||
Async functions get a true async wrapper (awaits all backend calls).
|
||||
Sync functions get a sync wrapper that drives the async backends via a
|
||||
temporary event loop — not suitable inside a running loop; prefer decorating
|
||||
async functions when integrating with async frameworks like FastAPI.
|
||||
|
||||
Args:
|
||||
similarity: Cosine similarity threshold (1.0=exact, 0.9=semantic)
|
||||
ttl: Time-to-live in seconds (None=forever)
|
||||
backend: Async storage backend (None=in-memory)
|
||||
namespace: Cache namespace for isolation
|
||||
enabled: Whether caching is enabled
|
||||
key_func: Custom cache key function
|
||||
|
||||
Returns:
|
||||
Decorated function with caching
|
||||
|
||||
Examples:
|
||||
>>> @cache(similarity=0.9, ttl=3600)
|
||||
... async def ask_llm(prompt: str) -> str:
|
||||
... return await call_ollama(prompt)
|
||||
|
||||
>>> @cache()
|
||||
... def ask_llm_sync(prompt: str) -> str:
|
||||
... return call_ollama_sync(prompt)
|
||||
"""
|
||||
_backend = backend or MemoryBackend()
|
||||
embedding_cache = EmbeddingCache()
|
||||
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
# ── Async wrapper ────────────────────────────────────────────────
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not enabled:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
start_time = time.time()
|
||||
prompt = _extract_prompt(args, kwargs) # type: ignore[arg-type]
|
||||
normalized = normalize_prompt(prompt)
|
||||
cache_key = (
|
||||
key_func(*args, **kwargs) # type: ignore[arg-type]
|
||||
if key_func
|
||||
else hash_prompt(normalized, namespace)
|
||||
)
|
||||
|
||||
# 1. Exact match
|
||||
entry = await _backend.get(cache_key)
|
||||
if entry is not None:
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
_stats_manager.record_hit(
|
||||
namespace,
|
||||
latency_saved_ms=latency_ms,
|
||||
saved_cost=entry.estimate_cost(0.001, 0.002),
|
||||
)
|
||||
return entry.response # type: ignore[return-value]
|
||||
|
||||
# 2. Semantic match
|
||||
if similarity < 1.0:
|
||||
query_embedding = await embedding_cache.aencode(normalized)
|
||||
result = await _backend.find_similar(
|
||||
query_embedding, threshold=similarity, namespace=namespace
|
||||
)
|
||||
if result is not None:
|
||||
_, matched_entry, _ = result
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
_stats_manager.record_hit(
|
||||
namespace,
|
||||
latency_saved_ms=latency_ms,
|
||||
saved_cost=matched_entry.estimate_cost(0.001, 0.002),
|
||||
)
|
||||
return matched_entry.response # type: ignore[return-value]
|
||||
|
||||
# 3. Cache miss — call through
|
||||
_stats_manager.record_miss(namespace)
|
||||
|
||||
try:
|
||||
response = await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
raise PromptCacheError(f"LLM function call failed: {e}") from e
|
||||
|
||||
embedding = None
|
||||
if similarity < 1.0:
|
||||
embedding = await embedding_cache.aencode(normalized)
|
||||
|
||||
await _backend.set(
|
||||
cache_key,
|
||||
CacheEntry(
|
||||
prompt=normalized,
|
||||
response=response,
|
||||
embedding=embedding,
|
||||
created_at=time.time(),
|
||||
ttl=ttl,
|
||||
namespace=namespace,
|
||||
hit_count=0,
|
||||
input_tokens=len(normalized) // 4,
|
||||
output_tokens=len(str(response)) // 4,
|
||||
),
|
||||
)
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
return async_wrapper # type: ignore[return-value]
|
||||
|
||||
else:
|
||||
# ── Sync wrapper (backwards compatibility) ───────────────────────
|
||||
# Drives async backends via a dedicated event loop per call.
|
||||
# Do NOT use inside a running event loop (e.g. FastAPI handlers).
|
||||
import asyncio
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
start_time = time.time()
|
||||
prompt = _extract_prompt(args, kwargs) # type: ignore[arg-type]
|
||||
normalized = normalize_prompt(prompt)
|
||||
cache_key = (
|
||||
key_func(*args, **kwargs) # type: ignore[arg-type]
|
||||
if key_func
|
||||
else hash_prompt(normalized, namespace)
|
||||
)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# 1. Exact match
|
||||
entry = loop.run_until_complete(_backend.get(cache_key))
|
||||
if entry is not None:
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
_stats_manager.record_hit(
|
||||
namespace,
|
||||
latency_saved_ms=latency_ms,
|
||||
saved_cost=entry.estimate_cost(0.001, 0.002),
|
||||
)
|
||||
return entry.response # type: ignore[return-value]
|
||||
|
||||
# 2. Semantic match
|
||||
if similarity < 1.0:
|
||||
query_embedding = embedding_cache.encode(normalized)
|
||||
result = loop.run_until_complete(
|
||||
_backend.find_similar(
|
||||
query_embedding, threshold=similarity, namespace=namespace
|
||||
)
|
||||
)
|
||||
if result is not None:
|
||||
_, matched_entry, _ = result
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
_stats_manager.record_hit(
|
||||
namespace,
|
||||
latency_saved_ms=latency_ms,
|
||||
saved_cost=matched_entry.estimate_cost(0.001, 0.002),
|
||||
)
|
||||
return matched_entry.response # type: ignore[return-value]
|
||||
|
||||
# 3. Cache miss
|
||||
_stats_manager.record_miss(namespace)
|
||||
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
raise PromptCacheError(f"LLM function call failed: {e}") from e
|
||||
|
||||
embedding = None
|
||||
if similarity < 1.0:
|
||||
embedding = embedding_cache.encode(normalized)
|
||||
|
||||
loop.run_until_complete(
|
||||
_backend.set(
|
||||
cache_key,
|
||||
CacheEntry(
|
||||
prompt=normalized,
|
||||
response=response,
|
||||
embedding=embedding,
|
||||
created_at=time.time(),
|
||||
ttl=ttl,
|
||||
namespace=namespace,
|
||||
hit_count=0,
|
||||
input_tokens=len(normalized) // 4,
|
||||
output_tokens=len(str(response)) // 4,
|
||||
),
|
||||
)
|
||||
)
|
||||
return response # type: ignore[return-value]
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return sync_wrapper # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Global default backend for utility functions
|
||||
_default_backend: Optional[BaseBackend] = None
|
||||
|
||||
|
||||
def get_default_backend() -> BaseBackend:
|
||||
"""Get default storage backend."""
|
||||
global _default_backend
|
||||
if _default_backend is None:
|
||||
_default_backend = MemoryBackend()
|
||||
return _default_backend
|
||||
|
||||
|
||||
def set_default_backend(backend: BaseBackend) -> None:
|
||||
"""Set default storage backend."""
|
||||
global _default_backend
|
||||
_default_backend = backend
|
||||
_stats_manager.set_backend(backend)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"cache",
|
||||
"CacheContext",
|
||||
"CachedLLM",
|
||||
"get_default_backend",
|
||||
"set_default_backend",
|
||||
]
|
||||
25
semantic_llm_cache/exceptions.py
Normal file
25
semantic_llm_cache/exceptions.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
"""Custom exceptions for prompt-cache."""
|
||||
|
||||
|
||||
class PromptCacheError(Exception):
|
||||
"""Base exception for prompt-cache errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CacheBackendError(PromptCacheError):
|
||||
"""Exception raised when backend operations fail."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CacheSerializationError(PromptCacheError):
|
||||
"""Exception raised when serialization/deserialization fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CacheNotFoundError(PromptCacheError):
|
||||
"""Exception raised when cache entry is not found."""
|
||||
|
||||
pass
|
||||
1
semantic_llm_cache/py.typed
Normal file
1
semantic_llm_cache/py.typed
Normal file
|
|
@ -0,0 +1 @@
|
|||
# PEP 561 marker file for type hints
|
||||
283
semantic_llm_cache/similarity.py
Normal file
283
semantic_llm_cache/similarity.py
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
"""Embedding generation and similarity matching for llm-semantic-cache."""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from semantic_llm_cache.exceptions import PromptCacheError
|
||||
|
||||
|
||||
class EmbeddingProvider:
|
||||
"""Base class for embedding providers."""
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate embedding for text.
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DummyEmbeddingProvider(EmbeddingProvider):
|
||||
"""Fallback embedding provider using hash-based vectors.
|
||||
|
||||
Provides consistent embeddings without external dependencies.
|
||||
Not semantically meaningful but provides consistent cache keys.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 384) -> None:
|
||||
"""Initialize dummy provider.
|
||||
|
||||
Args:
|
||||
dim: Embedding dimension (matches MiniLM default)
|
||||
"""
|
||||
self._dim = dim
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate hash-based embedding for text.
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Deterministic embedding vector based on text hash
|
||||
"""
|
||||
hash_obj = hashlib.sha256(text.encode())
|
||||
hash_bytes = hash_obj.digest()
|
||||
|
||||
values = np.frombuffer(hash_bytes, dtype=np.uint8)[: self._dim].astype(
|
||||
np.float32
|
||||
)
|
||||
|
||||
if len(values) < self._dim:
|
||||
values = np.pad(values, (0, self._dim - len(values)))
|
||||
|
||||
norm = np.linalg.norm(values)
|
||||
if norm > 0:
|
||||
values = values / norm
|
||||
|
||||
return values.tolist()
|
||||
|
||||
|
||||
class SentenceTransformerProvider(EmbeddingProvider):
|
||||
"""Sentence-transformers based embedding provider.
|
||||
|
||||
Uses local models like MiniLM for semantic embeddings.
|
||||
Inference is CPU/GPU-bound; use aencode() from async contexts.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
|
||||
"""Initialize sentence-transformer provider.
|
||||
|
||||
Args:
|
||||
model_name: Name of sentence-transformer model
|
||||
"""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError as e:
|
||||
raise PromptCacheError(
|
||||
"sentence-transformers package required for semantic matching. "
|
||||
"Install with: pip install semantic-llm-cache[semantic]"
|
||||
) from e
|
||||
|
||||
self._model = SentenceTransformer(model_name)
|
||||
self._dim = self._model.get_sentence_embedding_dimension()
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate embedding for text (blocking — use aencode from async code).
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Normalized embedding vector
|
||||
"""
|
||||
embedding = self._model.encode(text, convert_to_numpy=True)
|
||||
embedding = np.asarray(embedding, dtype=np.float32)
|
||||
|
||||
norm = np.linalg.norm(embedding)
|
||||
if norm > 0:
|
||||
embedding = embedding / norm
|
||||
|
||||
return embedding.tolist()
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI API-based embedding provider.
|
||||
|
||||
Uses OpenAI's embedding API for high-quality semantic embeddings.
|
||||
Network I/O — always use aencode() from async contexts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, api_key: Optional[str] = None, model: str = "text-embedding-3-small"
|
||||
) -> None:
|
||||
"""Initialize OpenAI embedding provider.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key (uses OPENAI_API_KEY env var if None)
|
||||
model: OpenAI embedding model to use
|
||||
"""
|
||||
try:
|
||||
import openai
|
||||
except ImportError as e:
|
||||
raise PromptCacheError(
|
||||
"openai package required for OpenAI embeddings. "
|
||||
"Install with: pip install semantic-llm-cache[openai]"
|
||||
) from e
|
||||
|
||||
self._client = openai.OpenAI(api_key=api_key)
|
||||
self._model = model
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate embedding for text (blocking — use aencode from async code).
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
OpenAI embedding vector (already normalized)
|
||||
"""
|
||||
response = self._client.embeddings.create(input=text, model=self._model)
|
||||
embedding = response.data[0].embedding
|
||||
|
||||
embedding_arr = np.asarray(embedding, dtype=np.float32)
|
||||
norm = np.linalg.norm(embedding_arr)
|
||||
if norm > 0:
|
||||
embedding_arr = embedding_arr / norm
|
||||
|
||||
return embedding_arr.tolist()
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two vectors.
|
||||
|
||||
Args:
|
||||
a: First vector
|
||||
b: Second vector
|
||||
|
||||
Returns:
|
||||
Similarity score between 0 and 1
|
||||
|
||||
Raises:
|
||||
ValueError: If vectors have different dimensions
|
||||
"""
|
||||
a_arr = np.asarray(a, dtype=np.float32)
|
||||
b_arr = np.asarray(b, dtype=np.float32)
|
||||
|
||||
if a_arr.shape != b_arr.shape:
|
||||
raise ValueError(
|
||||
f"Vector dimension mismatch: {a_arr.shape} != {b_arr.shape}"
|
||||
)
|
||||
|
||||
dot_product = np.dot(a_arr, b_arr)
|
||||
norm_a = np.linalg.norm(a_arr)
|
||||
norm_b = np.linalg.norm(b_arr)
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
|
||||
return float(dot_product / (norm_a * norm_b))
|
||||
|
||||
|
||||
def _encode_with_provider(text: str, provider: EmbeddingProvider) -> tuple[float, ...]:
|
||||
"""Helper function for LRU cache encoding.
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
provider: Embedding provider
|
||||
|
||||
Returns:
|
||||
Embedding as tuple for hashability
|
||||
"""
|
||||
return tuple(provider.encode(text))
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Cache for embedding generation with LRU eviction.
|
||||
|
||||
Use encode() from sync contexts, aencode() from async contexts.
|
||||
aencode() offloads blocking inference to a thread pool via asyncio.to_thread.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: Optional[EmbeddingProvider] = None,
|
||||
cache_size: int = 1024,
|
||||
) -> None:
|
||||
"""Initialize embedding cache.
|
||||
|
||||
Args:
|
||||
provider: Embedding provider (uses DummyEmbeddingProvider if None)
|
||||
cache_size: Maximum number of embeddings to cache
|
||||
"""
|
||||
self._provider = provider or DummyEmbeddingProvider()
|
||||
self._cache_size = cache_size
|
||||
self._get_cached = lru_cache(maxsize=cache_size)(_encode_with_provider)
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate embedding with LRU caching (sync, blocking).
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Embedding vector
|
||||
"""
|
||||
return list(self._get_cached(text, self._provider))
|
||||
|
||||
async def aencode(self, text: str) -> list[float]:
|
||||
"""Generate embedding with LRU caching (async, non-blocking).
|
||||
|
||||
CPU/network-bound work is offloaded to the default thread pool via
|
||||
asyncio.to_thread, keeping the event loop free.
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Embedding vector
|
||||
"""
|
||||
return await asyncio.to_thread(self.encode, text)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the embedding LRU cache."""
|
||||
self._get_cached.cache_clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider_type: str = "auto",
|
||||
model_name: Optional[str] = None,
|
||||
) -> EmbeddingProvider:
|
||||
"""Create embedding provider based on type.
|
||||
|
||||
Args:
|
||||
provider_type: Type of provider ("auto", "sentence-transformer", "openai", "dummy")
|
||||
model_name: Optional model name to use
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
"""
|
||||
if provider_type == "auto":
|
||||
try:
|
||||
return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")
|
||||
except PromptCacheError:
|
||||
return DummyEmbeddingProvider()
|
||||
|
||||
if provider_type == "sentence-transformer":
|
||||
return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")
|
||||
|
||||
if provider_type == "openai":
|
||||
return OpenAIEmbeddingProvider(model=model_name)
|
||||
|
||||
if provider_type == "dummy":
|
||||
return DummyEmbeddingProvider()
|
||||
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
255
semantic_llm_cache/stats.py
Normal file
255
semantic_llm_cache/stats.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
"""Statistics and analytics for llm-semantic-cache."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from semantic_llm_cache.backends import MemoryBackend
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Statistics for cache performance."""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
total_saved_ms: float = 0.0
|
||||
estimated_savings_usd: float = 0.0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
total = self.hits + self.misses
|
||||
return self.hits / max(total, 1)
|
||||
|
||||
@property
|
||||
def total_requests(self) -> int:
|
||||
"""Get total number of requests."""
|
||||
return self.hits + self.misses
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"hit_rate": self.hit_rate,
|
||||
"total_requests": self.total_requests,
|
||||
"total_saved_ms": self.total_saved_ms,
|
||||
"estimated_savings_usd": self.estimated_savings_usd,
|
||||
}
|
||||
|
||||
def __iadd__(self, other: "CacheStats") -> "CacheStats":
|
||||
self.hits += other.hits
|
||||
self.misses += other.misses
|
||||
self.total_saved_ms += other.total_saved_ms
|
||||
self.estimated_savings_usd += other.estimated_savings_usd
|
||||
return self
|
||||
|
||||
|
||||
class _StatsManager:
|
||||
"""Manager for global cache statistics.
|
||||
|
||||
Uses threading.Lock for record_hit/record_miss — these are simple counter
|
||||
increments with no awaits inside the lock, so threading.Lock is safe and
|
||||
avoids the overhead of asyncio.Lock for hot-path calls.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize stats manager."""
|
||||
self._stats: dict[str, CacheStats] = {}
|
||||
self._lock = Lock()
|
||||
self._default_backend: Optional[BaseBackend] = None
|
||||
|
||||
def get_backend(self) -> BaseBackend:
|
||||
"""Get default backend for cache operations."""
|
||||
if self._default_backend is None:
|
||||
self._default_backend = MemoryBackend()
|
||||
return self._default_backend
|
||||
|
||||
def set_backend(self, backend: BaseBackend) -> None:
|
||||
"""Set default backend for cache operations."""
|
||||
with self._lock:
|
||||
self._default_backend = backend
|
||||
|
||||
def record_hit(
|
||||
self,
|
||||
namespace: str,
|
||||
latency_saved_ms: float = 0.0,
|
||||
saved_cost: float = 0.0,
|
||||
) -> None:
|
||||
"""Record a cache hit (sync, safe to call from async context)."""
|
||||
with self._lock:
|
||||
if namespace not in self._stats:
|
||||
self._stats[namespace] = CacheStats()
|
||||
stats = self._stats[namespace]
|
||||
stats.hits += 1
|
||||
stats.total_saved_ms += latency_saved_ms
|
||||
stats.estimated_savings_usd += saved_cost
|
||||
|
||||
def record_miss(self, namespace: str) -> None:
|
||||
"""Record a cache miss (sync, safe to call from async context)."""
|
||||
with self._lock:
|
||||
if namespace not in self._stats:
|
||||
self._stats[namespace] = CacheStats()
|
||||
self._stats[namespace].misses += 1
|
||||
|
||||
def get_stats(self, namespace: Optional[str] = None) -> CacheStats:
|
||||
"""Get statistics for namespace or all."""
|
||||
with self._lock:
|
||||
if namespace is not None:
|
||||
return self._stats.get(namespace, CacheStats())
|
||||
|
||||
total = CacheStats()
|
||||
for stats in self._stats.values():
|
||||
total += stats
|
||||
return total
|
||||
|
||||
def clear_stats(self, namespace: Optional[str] = None) -> None:
|
||||
"""Clear statistics for namespace or all."""
|
||||
with self._lock:
|
||||
if namespace is None:
|
||||
self._stats.clear()
|
||||
elif namespace in self._stats:
|
||||
del self._stats[namespace]
|
||||
|
||||
|
||||
# Global stats manager instance
|
||||
_stats_manager = _StatsManager()
|
||||
|
||||
|
||||
def get_stats(namespace: Optional[str] = None) -> dict[str, Any]:
|
||||
"""Get cache statistics (sync).
|
||||
|
||||
Args:
|
||||
namespace: Optional namespace to filter by
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics
|
||||
"""
|
||||
return _stats_manager.get_stats(namespace).to_dict()
|
||||
|
||||
|
||||
async def clear_cache(namespace: Optional[str] = None) -> int:
|
||||
"""Clear all cached entries (async).
|
||||
|
||||
Args:
|
||||
namespace: Optional namespace to clear (None = all)
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
backend = _stats_manager.get_backend()
|
||||
|
||||
if namespace is None:
|
||||
stats = await backend.get_stats()
|
||||
size = stats.get("size", 0)
|
||||
await backend.clear()
|
||||
_stats_manager.clear_stats()
|
||||
return size
|
||||
|
||||
entries = await backend.iterate(namespace=namespace)
|
||||
count = len(entries)
|
||||
for key, _ in entries:
|
||||
await backend.delete(key)
|
||||
_stats_manager.clear_stats(namespace)
|
||||
return count
|
||||
|
||||
|
||||
async def invalidate(
|
||||
pattern: str,
|
||||
namespace: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Invalidate cache entries matching pattern (async).
|
||||
|
||||
Args:
|
||||
pattern: String pattern to match in prompts
|
||||
namespace: Optional namespace filter
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
backend = _stats_manager.get_backend()
|
||||
entries = await backend.iterate(namespace=namespace)
|
||||
count = 0
|
||||
pattern_lower = pattern.lower()
|
||||
|
||||
for key, entry in entries:
|
||||
if pattern_lower in entry.prompt.lower():
|
||||
await backend.delete(key)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
async def warm_cache(
|
||||
prompts: list[str],
|
||||
llm_func: Callable[[str], Any],
|
||||
namespace: str = "default",
|
||||
) -> int:
|
||||
"""Pre-populate cache with prompts (async).
|
||||
|
||||
Args:
|
||||
prompts: List of prompts to cache
|
||||
llm_func: Async or sync LLM function to call for each prompt
|
||||
namespace: Cache namespace to use
|
||||
|
||||
Returns:
|
||||
Number of prompts attempted
|
||||
"""
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
from semantic_llm_cache.core import cache
|
||||
|
||||
cached_func = cache(namespace=namespace)(llm_func)
|
||||
|
||||
for prompt in prompts:
|
||||
try:
|
||||
result = cached_func(prompt)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return len(prompts)
|
||||
|
||||
|
||||
async def export_cache(
|
||||
namespace: Optional[str] = None,
|
||||
filepath: Optional[str] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Export cache entries for analysis (async).
|
||||
|
||||
Args:
|
||||
namespace: Optional namespace filter
|
||||
filepath: Optional file path to save export (JSON)
|
||||
|
||||
Returns:
|
||||
List of cache entry dictionaries
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
backend = _stats_manager.get_backend()
|
||||
entries = await backend.iterate(namespace=namespace)
|
||||
|
||||
export_data = []
|
||||
for key, entry in entries:
|
||||
export_data.append({
|
||||
"key": key,
|
||||
"prompt": entry.prompt,
|
||||
"response": str(entry.response)[:1000],
|
||||
"namespace": entry.namespace,
|
||||
"hit_count": entry.hit_count,
|
||||
"created_at": datetime.fromtimestamp(entry.created_at).isoformat(),
|
||||
"ttl": entry.ttl,
|
||||
"input_tokens": entry.input_tokens,
|
||||
"output_tokens": entry.output_tokens,
|
||||
})
|
||||
|
||||
if filepath:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(export_data, f, indent=2)
|
||||
|
||||
return export_data
|
||||
111
semantic_llm_cache/storage.py
Normal file
111
semantic_llm_cache/storage.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Storage backend interface for prompt-cache."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
|
||||
|
||||
class StorageBackend(ABC):
|
||||
"""Abstract base class for async cache storage backends."""
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Retrieve cache entry by key.
|
||||
|
||||
Args:
|
||||
key: Cache key to retrieve
|
||||
|
||||
Returns:
|
||||
CacheEntry if found and not expired, None otherwise
|
||||
|
||||
Raises:
|
||||
CacheBackendError: If backend operation fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set(self, key: str, entry: CacheEntry) -> None:
|
||||
"""Store cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to store under
|
||||
entry: CacheEntry to store
|
||||
|
||||
Raises:
|
||||
CacheBackendError: If backend operation fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete
|
||||
|
||||
Returns:
|
||||
True if entry was deleted, False if not found
|
||||
|
||||
Raises:
|
||||
CacheBackendError: If backend operation fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> None:
|
||||
"""Clear all cache entries.
|
||||
|
||||
Raises:
|
||||
CacheBackendError: If backend operation fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def iterate(self, namespace: Optional[str] = None) -> list[tuple[str, CacheEntry]]:
|
||||
"""Iterate over cache entries, optionally filtered by namespace.
|
||||
|
||||
Args:
|
||||
namespace: Optional namespace filter
|
||||
|
||||
Returns:
|
||||
List of (key, entry) tuples
|
||||
|
||||
Raises:
|
||||
CacheBackendError: If backend operation fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def find_similar(
|
||||
self,
|
||||
embedding: list[float],
|
||||
threshold: float,
|
||||
namespace: Optional[str] = None,
|
||||
) -> Optional[tuple[str, CacheEntry, float]]:
|
||||
"""Find semantically similar cached entry.
|
||||
|
||||
Args:
|
||||
embedding: Query embedding vector
|
||||
threshold: Minimum similarity score (0-1)
|
||||
namespace: Optional namespace filter
|
||||
|
||||
Returns:
|
||||
(key, entry, similarity) tuple if found above threshold, None otherwise
|
||||
|
||||
Raises:
|
||||
CacheBackendError: If backend operation fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""Get backend statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with stats like size, memory_usage, etc.
|
||||
|
||||
Raises:
|
||||
CacheBackendError: If backend operation fails
|
||||
"""
|
||||
pass
|
||||
97
semantic_llm_cache/utils/__init__.py
Normal file
97
semantic_llm_cache/utils/__init__.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""Utility functions for prompt-cache."""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
def normalize_prompt(prompt: str) -> str:
|
||||
"""Normalize prompt text for consistent caching.
|
||||
|
||||
Args:
|
||||
prompt: Raw prompt text
|
||||
|
||||
Returns:
|
||||
Normalized prompt text
|
||||
"""
|
||||
# Remove extra whitespace
|
||||
prompt = " ".join(prompt.split())
|
||||
|
||||
# Lowercase for better matching (optional - can affect semantics)
|
||||
# prompt = prompt.lower()
|
||||
|
||||
# Remove common filler words at start
|
||||
filler_pattern = r"^(please|can you|could you|i need|i want)\s+"
|
||||
prompt = re.sub(filler_pattern, "", prompt, flags=re.IGNORECASE)
|
||||
|
||||
# Normalize quotes
|
||||
prompt = prompt.replace('"', "'").replace("`", "'")
|
||||
|
||||
# Remove trailing punctuation
|
||||
prompt = prompt.rstrip("?!.")
|
||||
|
||||
return prompt.strip()
|
||||
|
||||
|
||||
def hash_prompt(prompt: str, namespace: str = "default") -> str:
|
||||
"""Generate cache key from prompt and namespace.
|
||||
|
||||
Args:
|
||||
prompt: Prompt text
|
||||
namespace: Cache namespace
|
||||
|
||||
Returns:
|
||||
Hash-based cache key
|
||||
"""
|
||||
combined = f"{namespace}:{prompt}"
|
||||
return hashlib.sha256(combined.encode()).hexdigest()
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Estimate token count for text (rough approximation).
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
# Rough approximation: ~4 chars per token
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
def serialize_response(response: Any) -> str:
|
||||
"""Serialize response for storage.
|
||||
|
||||
Args:
|
||||
response: Response object (string, dict, etc.)
|
||||
|
||||
Returns:
|
||||
Serialized JSON string
|
||||
"""
|
||||
import json
|
||||
|
||||
return json.dumps(response)
|
||||
|
||||
|
||||
def deserialize_response(data: str) -> Any:
|
||||
"""Deserialize response from storage.
|
||||
|
||||
Args:
|
||||
data: Serialized JSON string
|
||||
|
||||
Returns:
|
||||
Deserialized response object
|
||||
"""
|
||||
import json
|
||||
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"normalize_prompt",
|
||||
"hash_prompt",
|
||||
"estimate_tokens",
|
||||
"serialize_response",
|
||||
"deserialize_response",
|
||||
]
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Tests for prompt-cache."""
|
||||
50
tests/conftest.py
Normal file
50
tests/conftest.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""Pytest configuration and fixtures for prompt-cache tests."""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from semantic_llm_cache.backends import MemoryBackend
|
||||
from semantic_llm_cache.config import CacheConfig, CacheEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def backend():
|
||||
"""Provide a fresh memory backend for each test."""
|
||||
return MemoryBackend()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_config():
|
||||
"""Provide default cache configuration."""
|
||||
return CacheConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_entry():
|
||||
"""Provide a sample cache entry."""
|
||||
return CacheEntry(
|
||||
prompt="What is Python?",
|
||||
response="Python is a programming language.",
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
created_at=time.time(), # Use current time
|
||||
ttl=3600,
|
||||
namespace="default",
|
||||
hit_count=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_func():
|
||||
"""Provide a mock LLM function."""
|
||||
responses = {
|
||||
"What is Python?": "Python is a programming language.",
|
||||
"What's Python?": "Python is a programming language.",
|
||||
"Explain Python": "Python is a high-level programming language.",
|
||||
"What is Rust?": "Rust is a systems programming language.",
|
||||
}
|
||||
|
||||
def _func(prompt: str) -> str:
|
||||
return responses.get(prompt, f"Response to: {prompt}")
|
||||
|
||||
return _func
|
||||
687
tests/test_backends.py
Normal file
687
tests/test_backends.py
Normal file
|
|
@ -0,0 +1,687 @@
|
|||
"""Tests for storage backends."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from semantic_llm_cache.backends import MemoryBackend
|
||||
from semantic_llm_cache.backends.sqlite import SQLiteBackend # noqa: F401
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
|
||||
class TestBaseBackend:
|
||||
"""Tests for BaseBackend abstract class."""
|
||||
|
||||
def test_cosine_similarity(self):
|
||||
"""Test cosine similarity helper method."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
entry1 = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
embedding=[1.0, 0.0, 0.0],
|
||||
)
|
||||
entry2 = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
embedding=[1.0, 0.0, 0.0],
|
||||
)
|
||||
entry3 = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
embedding=[0.0, 1.0, 0.0],
|
||||
)
|
||||
|
||||
# Test _find_best_match
|
||||
candidates = [("key1", entry1), ("key2", entry2), ("key3", entry3)]
|
||||
|
||||
# Query matching entry1
|
||||
result = backend._find_best_match(candidates, [1.0, 0.0, 0.0], threshold=0.9)
|
||||
assert result is not None
|
||||
key, entry, sim = result
|
||||
assert sim == pytest.approx(1.0)
|
||||
|
||||
|
||||
class TestMemoryBackend:
|
||||
"""Tests for MemoryBackend."""
|
||||
|
||||
def test_set_and_get(self, backend, sample_entry):
|
||||
"""Test basic set and get operations."""
|
||||
backend.set("key1", sample_entry)
|
||||
retrieved = backend.get("key1")
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.prompt == sample_entry.prompt
|
||||
assert retrieved.response == sample_entry.response
|
||||
|
||||
def test_get_nonexistent(self, backend):
|
||||
"""Test getting non-existent key returns None."""
|
||||
result = backend.get("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_delete(self, backend, sample_entry):
|
||||
"""Test delete operation."""
|
||||
backend.set("key1", sample_entry)
|
||||
assert backend.get("key1") is not None
|
||||
|
||||
assert backend.delete("key1") is True
|
||||
assert backend.get("key1") is None
|
||||
|
||||
def test_delete_nonexistent(self, backend):
|
||||
"""Test deleting non-existent key returns False."""
|
||||
assert backend.delete("nonexistent") is False
|
||||
|
||||
def test_clear(self, backend, sample_entry):
|
||||
"""Test clear operation."""
|
||||
backend.set("key1", sample_entry)
|
||||
backend.set("key2", sample_entry)
|
||||
backend.clear()
|
||||
|
||||
assert backend.get("key1") is None
|
||||
assert backend.get("key2") is None
|
||||
|
||||
def test_iterate_all(self, backend):
|
||||
"""Test iterating over all entries."""
|
||||
entry1 = CacheEntry(prompt="p1", response="r1", created_at=time.time())
|
||||
entry2 = CacheEntry(prompt="p2", response="r2", created_at=time.time())
|
||||
|
||||
backend.set("key1", entry1)
|
||||
backend.set("key2", entry2)
|
||||
|
||||
results = backend.iterate()
|
||||
assert len(results) == 2
|
||||
|
||||
def test_iterate_with_namespace(self, backend):
|
||||
"""Test iterating with namespace filter."""
|
||||
entry1 = CacheEntry(
|
||||
prompt="p1", response="r1", namespace="ns1", created_at=time.time()
|
||||
)
|
||||
entry2 = CacheEntry(
|
||||
prompt="p2", response="r2", namespace="ns2", created_at=time.time()
|
||||
)
|
||||
|
||||
backend.set("key1", entry1)
|
||||
backend.set("key2", entry2)
|
||||
|
||||
results = backend.iterate(namespace="ns1")
|
||||
assert len(results) == 1
|
||||
assert results[0][1].namespace == "ns1"
|
||||
|
||||
def test_find_similar(self, backend):
|
||||
"""Test finding semantically similar entries."""
|
||||
entry1 = CacheEntry(
|
||||
prompt="What is Python?",
|
||||
response="r1",
|
||||
embedding=[1.0, 0.0, 0.0],
|
||||
created_at=time.time(),
|
||||
)
|
||||
entry2 = CacheEntry(
|
||||
prompt="What is Rust?",
|
||||
response="r2",
|
||||
embedding=[0.0, 1.0, 0.0],
|
||||
created_at=time.time(),
|
||||
)
|
||||
|
||||
backend.set("key1", entry1)
|
||||
backend.set("key2", entry2)
|
||||
|
||||
# Find similar to entry1
|
||||
result = backend.find_similar([1.0, 0.0, 0.0], threshold=0.9)
|
||||
assert result is not None
|
||||
key, entry, sim = result
|
||||
assert key == "key1"
|
||||
|
||||
def test_find_similar_no_match(self, backend):
|
||||
"""Test find_similar returns None when below threshold."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
embedding=[1.0, 0.0, 0.0],
|
||||
created_at=time.time(),
|
||||
)
|
||||
|
||||
backend.set("key1", entry)
|
||||
|
||||
# Query with orthogonal vector
|
||||
result = backend.find_similar([0.0, 1.0, 0.0], threshold=0.9)
|
||||
assert result is None
|
||||
|
||||
def test_get_stats(self, backend):
|
||||
"""Test get_stats returns correct info."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
created_at=time.time(),
|
||||
)
|
||||
|
||||
backend.set("key1", entry)
|
||||
stats = backend.get_stats()
|
||||
|
||||
assert stats["size"] == 1
|
||||
assert "hits" in stats
|
||||
assert "misses" in stats
|
||||
|
||||
def test_hit_count_increments(self, backend, sample_entry):
|
||||
"""Test hit count increments on cache hit."""
|
||||
backend.set("key1", sample_entry)
|
||||
|
||||
backend.get("key1") # First hit
|
||||
backend.get("key1") # Second hit
|
||||
|
||||
entry = backend.get("key1")
|
||||
assert entry.hit_count >= 1
|
||||
|
||||
def test_lru_eviction(self):
|
||||
"""Test LRU eviction when max_size is reached."""
|
||||
backend = MemoryBackend(max_size=2)
|
||||
|
||||
entry1 = CacheEntry(prompt="p1", response="r1", created_at=time.time())
|
||||
entry2 = CacheEntry(prompt="p2", response="r2", created_at=time.time())
|
||||
entry3 = CacheEntry(prompt="p3", response="r3", created_at=time.time())
|
||||
|
||||
backend.set("key1", entry1)
|
||||
backend.set("key2", entry2)
|
||||
assert backend.get_stats()["size"] == 2
|
||||
|
||||
backend.set("key3", entry3)
|
||||
# Should evict oldest (key1)
|
||||
assert backend.get_stats()["size"] == 2
|
||||
|
||||
def test_expired_entry_not_returned(self, backend):
|
||||
"""Test expired entries are not returned."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
ttl=1,
|
||||
created_at=time.time() - 2, # 2 seconds ago with 1s TTL
|
||||
)
|
||||
|
||||
backend.set("key1", entry)
|
||||
result = backend.get("key1")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSQLiteBackend:
|
||||
"""Tests for SQLiteBackend."""
|
||||
|
||||
@pytest.fixture
|
||||
def sqlite_backend(self, tmp_path):
|
||||
"""Create SQLite backend with temp database."""
|
||||
from semantic_llm_cache.backends.sqlite import SQLiteBackend
|
||||
|
||||
db_path = tmp_path / "test_cache.db"
|
||||
return SQLiteBackend(db_path)
|
||||
|
||||
def test_set_and_get(self, sqlite_backend, sample_entry):
|
||||
"""Test basic set and get operations."""
|
||||
sqlite_backend.set("key1", sample_entry)
|
||||
retrieved = sqlite_backend.get("key1")
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.prompt == sample_entry.prompt
|
||||
assert retrieved.response == sample_entry.response
|
||||
|
||||
def test_persistence(self, sqlite_backend, sample_entry, tmp_path):
|
||||
"""Test entries persist across backend instances."""
|
||||
db_path = tmp_path / "test_persist.db"
|
||||
|
||||
# Create first instance
|
||||
backend1 = SQLiteBackend(db_path)
|
||||
backend1.set("key1", sample_entry)
|
||||
|
||||
# Create second instance (simulates restart)
|
||||
backend2 = SQLiteBackend(db_path)
|
||||
retrieved = backend2.get("key1")
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.prompt == sample_entry.prompt
|
||||
|
||||
def test_get_stats(self, sqlite_backend):
|
||||
"""Test get_stats returns correct info."""
|
||||
entry = CacheEntry(prompt="test", response="response", created_at=time.time())
|
||||
sqlite_backend.set("key1", entry)
|
||||
|
||||
stats = sqlite_backend.get_stats()
|
||||
assert stats["size"] == 1
|
||||
assert "db_path" in stats
|
||||
|
||||
def test_clear(self, sqlite_backend, sample_entry):
|
||||
"""Test clear operation."""
|
||||
sqlite_backend.set("key1", sample_entry)
|
||||
sqlite_backend.clear()
|
||||
|
||||
assert sqlite_backend.get("key1") is None
|
||||
|
||||
def test_close_and_reopen(self, sqlite_backend, sample_entry):
|
||||
"""Test closing and reopening connection."""
|
||||
sqlite_backend.set("key1", sample_entry)
|
||||
sqlite_backend.close()
|
||||
|
||||
# Should be able to use after close (reopens connection)
|
||||
retrieved = sqlite_backend.get("key1")
|
||||
assert retrieved is not None
|
||||
|
||||
|
||||
class TestRedisBackend:
|
||||
"""Tests for RedisBackend."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""Create mock Redis client."""
|
||||
with patch("semantic_llm_cache.backends.redis.redis_lib") as mock:
|
||||
mock_client = MagicMock()
|
||||
mock.from_url.return_value = mock_client
|
||||
mock_client.ping.return_value = True
|
||||
mock_client.get.return_value = None
|
||||
mock_client.keys.return_value = []
|
||||
mock_client.delete.return_value = 1
|
||||
yield mock_client
|
||||
|
||||
@pytest.fixture
|
||||
def redis_backend(self, mock_redis):
|
||||
"""Create Redis backend with mocked client."""
|
||||
from semantic_llm_cache.backends.redis import RedisBackend
|
||||
|
||||
backend = RedisBackend(url="redis://localhost:6379/0")
|
||||
backend._redis = mock_redis
|
||||
return backend
|
||||
|
||||
def test_set_and_get(self, redis_backend, mock_redis, sample_entry):
|
||||
"""Test basic set and get operations."""
|
||||
# Mock get to return stored data
|
||||
mock_redis.get.return_value = json.dumps({
|
||||
"prompt": sample_entry.prompt,
|
||||
"response": sample_entry.response,
|
||||
"embedding": sample_entry.embedding,
|
||||
"created_at": sample_entry.created_at,
|
||||
"ttl": sample_entry.ttl,
|
||||
"namespace": sample_entry.namespace,
|
||||
"hit_count": 0,
|
||||
"input_tokens": sample_entry.input_tokens,
|
||||
"output_tokens": sample_entry.output_tokens,
|
||||
}).encode()
|
||||
|
||||
redis_backend.set("key1", sample_entry)
|
||||
retrieved = redis_backend.get("key1")
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.prompt == sample_entry.prompt
|
||||
# set is called twice: once for initial set, once to update hit_count
|
||||
assert mock_redis.set.call_count == 2
|
||||
|
||||
def test_get_nonexistent(self, redis_backend, mock_redis):
|
||||
"""Test getting non-existent key returns None."""
|
||||
mock_redis.get.return_value = None
|
||||
result = redis_backend.get("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_delete(self, redis_backend, mock_redis, sample_entry):
|
||||
"""Test delete operation."""
|
||||
mock_redis.delete.return_value = 1
|
||||
result = redis_backend.delete("key1")
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
def test_delete_nonexistent(self, redis_backend, mock_redis):
|
||||
"""Test deleting non-existent key returns False."""
|
||||
mock_redis.delete.return_value = 0
|
||||
result = redis_backend.delete("nonexistent")
|
||||
assert result is False
|
||||
|
||||
def test_clear(self, redis_backend, mock_redis):
|
||||
"""Test clear operation."""
|
||||
mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
|
||||
redis_backend.clear()
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
def test_clear_empty(self, redis_backend, mock_redis):
|
||||
"""Test clear with no entries."""
|
||||
mock_redis.keys.return_value = []
|
||||
redis_backend.clear()
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
def test_iterate_all(self, redis_backend, mock_redis, sample_entry):
|
||||
"""Test iterating over all entries."""
|
||||
entry_dict = {
|
||||
"prompt": sample_entry.prompt,
|
||||
"response": sample_entry.response,
|
||||
"embedding": sample_entry.embedding,
|
||||
"created_at": sample_entry.created_at,
|
||||
"ttl": sample_entry.ttl,
|
||||
"namespace": sample_entry.namespace,
|
||||
"hit_count": 0,
|
||||
"input_tokens": sample_entry.input_tokens,
|
||||
"output_tokens": sample_entry.output_tokens,
|
||||
}
|
||||
|
||||
mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
|
||||
mock_redis.get.return_value = json.dumps(entry_dict).encode()
|
||||
|
||||
results = redis_backend.iterate()
|
||||
assert len(results) == 2
|
||||
|
||||
def test_iterate_with_namespace(self, redis_backend, mock_redis, sample_entry):
|
||||
"""Test iterating with namespace filter."""
|
||||
entry1_dict = {
|
||||
"prompt": "p1",
|
||||
"response": "r1",
|
||||
"embedding": None,
|
||||
"created_at": time.time(),
|
||||
"ttl": None,
|
||||
"namespace": "ns1",
|
||||
"hit_count": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
entry2_dict = entry1_dict.copy()
|
||||
entry2_dict["namespace"] = "ns2"
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def mock_get(key):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return json.dumps(entry1_dict).encode()
|
||||
return json.dumps(entry2_dict).encode()
|
||||
|
||||
mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
|
||||
mock_redis.get.side_effect = mock_get
|
||||
|
||||
results = redis_backend.iterate(namespace="ns1")
|
||||
assert len(results) == 1
|
||||
assert results[0][1].namespace == "ns1"
|
||||
|
||||
def test_find_similar(self, redis_backend, mock_redis):
|
||||
"""Test finding semantically similar entries."""
|
||||
entry_dict = {
|
||||
"prompt": "What is Python?",
|
||||
"response": "r1",
|
||||
"embedding": [1.0, 0.0, 0.0],
|
||||
"created_at": time.time(),
|
||||
"ttl": None,
|
||||
"namespace": "default",
|
||||
"hit_count": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
|
||||
mock_redis.keys.return_value = [b"semantic_llm_cache:key1"]
|
||||
mock_redis.get.return_value = json.dumps(entry_dict).encode()
|
||||
|
||||
result = redis_backend.find_similar([1.0, 0.0, 0.0], threshold=0.9)
|
||||
assert result is not None
|
||||
key, entry, sim = result
|
||||
assert key == "key1"
|
||||
|
||||
def test_find_similar_no_match(self, redis_backend, mock_redis):
|
||||
"""Test find_similar returns None when below threshold."""
|
||||
entry_dict = {
|
||||
"prompt": "test",
|
||||
"response": "response",
|
||||
"embedding": [1.0, 0.0, 0.0],
|
||||
"created_at": time.time(),
|
||||
"ttl": None,
|
||||
"namespace": "default",
|
||||
"hit_count": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
|
||||
mock_redis.keys.return_value = [b"semantic_llm_cache:key1"]
|
||||
mock_redis.get.return_value = json.dumps(entry_dict).encode()
|
||||
|
||||
# Query with orthogonal vector
|
||||
result = redis_backend.find_similar([0.0, 1.0, 0.0], threshold=0.9)
|
||||
assert result is None
|
||||
|
||||
def test_get_stats(self, redis_backend, mock_redis):
|
||||
"""Test get_stats returns correct info."""
|
||||
mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"]
|
||||
stats = redis_backend.get_stats()
|
||||
|
||||
assert "prefix" in stats
|
||||
assert stats["size"] == 2
|
||||
assert stats["prefix"] == "semantic_llm_cache:"
|
||||
|
||||
def test_get_stats_error_handling(self, redis_backend, mock_redis):
|
||||
"""Test get_stats handles Redis errors gracefully."""
|
||||
mock_redis.keys.side_effect = Exception("Connection lost")
|
||||
stats = redis_backend.get_stats()
|
||||
|
||||
assert "error" in stats
|
||||
assert stats["size"] == 0
|
||||
|
||||
def test_make_key(self, redis_backend):
|
||||
"""Test key prefixing."""
|
||||
result = redis_backend._make_key("test_key")
|
||||
assert result == "semantic_llm_cache:test_key"
|
||||
|
||||
def test_entry_to_dict(self, redis_backend, sample_entry):
|
||||
"""Test converting entry to dictionary."""
|
||||
result = redis_backend._entry_to_dict(sample_entry)
|
||||
assert result["prompt"] == sample_entry.prompt
|
||||
assert result["response"] == sample_entry.response
|
||||
assert result["embedding"] == sample_entry.embedding
|
||||
|
||||
def test_dict_to_entry(self, redis_backend):
|
||||
"""Test converting dictionary to entry."""
|
||||
data = {
|
||||
"prompt": "test",
|
||||
"response": "response",
|
||||
"embedding": [1.0, 0.0],
|
||||
"created_at": time.time(),
|
||||
"ttl": 100,
|
||||
"namespace": "test_ns",
|
||||
"hit_count": 5,
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
}
|
||||
|
||||
entry = redis_backend._dict_to_entry(data)
|
||||
assert entry.prompt == "test"
|
||||
assert entry.namespace == "test_ns"
|
||||
assert entry.hit_count == 5
|
||||
|
||||
def test_dict_to_entry_defaults(self, redis_backend):
|
||||
"""Test dict_to_entry uses defaults for missing fields."""
|
||||
data = {
|
||||
"prompt": "test",
|
||||
"response": "response",
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
entry = redis_backend._dict_to_entry(data)
|
||||
assert entry.embedding is None
|
||||
assert entry.ttl is None
|
||||
assert entry.namespace == "default"
|
||||
assert entry.hit_count == 0
|
||||
assert entry.input_tokens == 0
|
||||
assert entry.output_tokens == 0
|
||||
|
||||
def test_connection_failure(self):
|
||||
"""Test connection failure raises CacheBackendError."""
|
||||
from semantic_llm_cache.backends.redis import RedisBackend
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
# Need to patch both the import and the from_url call
|
||||
with patch("semantic_llm_cache.backends.redis.redis_lib") as mock_redis:
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.side_effect = Exception("Connection refused")
|
||||
mock_redis.from_url.return_value = mock_client
|
||||
|
||||
with pytest.raises(CacheBackendError, match="Failed to connect"):
|
||||
RedisBackend(url="redis://localhost:9999/0")
|
||||
|
||||
def test_set_with_ttl(self, redis_backend, mock_redis, sample_entry):
|
||||
"""Test setting entry with TTL."""
|
||||
sample_entry.ttl = 3600
|
||||
redis_backend.set("key1", sample_entry)
|
||||
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args[1]["ex"] == 3600
|
||||
|
||||
def test_get_expired_entry(self, redis_backend, mock_redis):
|
||||
"""Test expired entry is not returned."""
|
||||
expired_dict = {
|
||||
"prompt": "test",
|
||||
"response": "response",
|
||||
"embedding": None,
|
||||
"created_at": time.time() - 1000,
|
||||
"ttl": 100, # 100 seconds TTL, created 1000 seconds ago
|
||||
"namespace": "default",
|
||||
"hit_count": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
|
||||
mock_redis.get.return_value = json.dumps(expired_dict).encode()
|
||||
mock_redis.delete.return_value = 1
|
||||
|
||||
result = redis_backend.get("expired_key")
|
||||
assert result is None
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
def test_close(self, redis_backend, mock_redis):
|
||||
"""Test closing Redis connection."""
|
||||
redis_backend.close()
|
||||
mock_redis.close.assert_called_once()
|
||||
|
||||
def test_close_error_handling(self, redis_backend, mock_redis):
|
||||
"""Test close handles errors gracefully."""
|
||||
mock_redis.close.side_effect = Exception("Close error")
|
||||
# Should not raise
|
||||
redis_backend.close()
|
||||
|
||||
def test_iterate_with_expired_entries(self, redis_backend, mock_redis):
|
||||
"""Test iterate filters out expired entries."""
|
||||
expired_dict = {
|
||||
"prompt": "expired",
|
||||
"response": "response",
|
||||
"embedding": None,
|
||||
"created_at": time.time() - 1000,
|
||||
"ttl": 100,
|
||||
"namespace": "default",
|
||||
"hit_count": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
valid_dict = {
|
||||
"prompt": "valid",
|
||||
"response": "response",
|
||||
"embedding": None,
|
||||
"created_at": time.time(),
|
||||
"ttl": None,
|
||||
"namespace": "default",
|
||||
"hit_count": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def mock_get(key):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return json.dumps(expired_dict).encode()
|
||||
return json.dumps(valid_dict).encode()
|
||||
|
||||
mock_redis.keys.return_value = [b"semantic_llm_cache:expired", b"semantic_llm_cache:valid"]
|
||||
mock_redis.get.side_effect = mock_get
|
||||
|
||||
results = redis_backend.iterate()
|
||||
# Only valid entry should be returned
|
||||
assert len(results) == 1
|
||||
assert results[0][1].prompt == "valid"
|
||||
|
||||
def test_set_error_handling(self, redis_backend, mock_redis, sample_entry):
|
||||
"""Test set handles Redis errors."""
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
mock_redis.set.side_effect = Exception("Redis error")
|
||||
|
||||
with pytest.raises(CacheBackendError, match="Failed to set"):
|
||||
redis_backend.set("key1", sample_entry)
|
||||
|
||||
def test_delete_error_handling(self, redis_backend, mock_redis):
|
||||
"""Test delete handles Redis errors."""
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
mock_redis.delete.side_effect = Exception("Redis error")
|
||||
|
||||
with pytest.raises(CacheBackendError, match="Failed to delete"):
|
||||
redis_backend.delete("key1")
|
||||
|
||||
def test_iterate_error_handling(self, redis_backend, mock_redis):
|
||||
"""Test iterate handles Redis errors."""
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
mock_redis.keys.side_effect = Exception("Redis error")
|
||||
|
||||
with pytest.raises(CacheBackendError, match="Failed to iterate"):
|
||||
redis_backend.iterate()
|
||||
|
||||
def test_get_json_error(self, redis_backend, mock_redis):
|
||||
"""Test get handles invalid JSON."""
|
||||
import json
|
||||
mock_redis.get.return_value = b"invalid json"
|
||||
|
||||
# The JSON decode error should be wrapped in CacheBackendError
|
||||
with pytest.raises((CacheBackendError, json.JSONDecodeError)):
|
||||
redis_backend.get("key1")
|
||||
|
||||
def test_import_error_without_package(self):
|
||||
"""Test ImportError when redis package not installed."""
|
||||
# This test validates the import guard in redis.py
|
||||
from semantic_llm_cache.backends import redis as redis_module
|
||||
|
||||
# Check that RedisBackend is defined
|
||||
assert hasattr(redis_module, "RedisBackend")
|
||||
|
||||
|
||||
class TestCacheEntry:
|
||||
"""Tests for CacheEntry dataclass."""
|
||||
|
||||
def test_is_expired_with_none_ttl(self):
|
||||
"""Test entry with None TTL never expires."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
ttl=None,
|
||||
created_at=time.time() - 10000,
|
||||
)
|
||||
assert not entry.is_expired(time.time())
|
||||
|
||||
def test_is_expired_with_ttl(self):
|
||||
"""Test entry with TTL expires correctly."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
ttl=10,
|
||||
created_at=time.time() - 15,
|
||||
)
|
||||
assert entry.is_expired(time.time())
|
||||
|
||||
def test_is_expired_not_yet(self):
|
||||
"""Test entry not yet expired."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
ttl=10,
|
||||
created_at=time.time() - 5,
|
||||
)
|
||||
assert not entry.is_expired(time.time())
|
||||
|
||||
def test_estimate_cost(self):
|
||||
"""Test cost estimation."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
)
|
||||
cost = entry.estimate_cost(0.001, 0.002)
|
||||
# 1000/1000 * 0.001 + 500/1000 * 0.002 = 0.001 + 0.001 = 0.002
|
||||
assert cost == pytest.approx(0.002)
|
||||
606
tests/test_core.py
Normal file
606
tests/test_core.py
Normal file
|
|
@ -0,0 +1,606 @@
|
|||
"""Tests for core cache decorator and API."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from semantic_llm_cache import CacheContext, CachedLLM, cache, get_default_backend, set_default_backend
|
||||
from semantic_llm_cache.backends import MemoryBackend
|
||||
from semantic_llm_cache.config import CacheConfig, CacheEntry
|
||||
from semantic_llm_cache.exceptions import PromptCacheError
|
||||
|
||||
|
||||
class TestCacheDecorator:
|
||||
"""Tests for @cache decorator."""
|
||||
|
||||
def test_exact_match_cache_hit(self, mock_llm_func):
|
||||
"""Test exact match caching returns cached result."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache()
|
||||
def cached_func(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return mock_llm_func(prompt)
|
||||
|
||||
# First call - cache miss
|
||||
result1 = cached_func("What is Python?")
|
||||
assert result1 == "Python is a programming language."
|
||||
assert call_count["count"] == 1
|
||||
|
||||
# Second call - cache hit
|
||||
result2 = cached_func("What is Python?")
|
||||
assert result2 == "Python is a programming language."
|
||||
assert call_count["count"] == 1 # No additional call
|
||||
|
||||
def test_cache_miss_different_prompt(self, mock_llm_func):
|
||||
"""Test different prompts result in cache misses."""
|
||||
@cache()
|
||||
def cached_func(prompt: str) -> str:
|
||||
return mock_llm_func(prompt)
|
||||
|
||||
result1 = cached_func("What is Python?")
|
||||
result2 = cached_func("What is Rust?")
|
||||
|
||||
assert result1 == "Python is a programming language."
|
||||
assert result2 == "Rust is a systems programming language."
|
||||
|
||||
def test_cache_disabled(self, mock_llm_func):
|
||||
"""Test caching can be disabled."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache(enabled=False)
|
||||
def cached_func(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return mock_llm_func(prompt)
|
||||
|
||||
cached_func("What is Python?")
|
||||
cached_func("What is Python?")
|
||||
|
||||
assert call_count["count"] == 2 # Both calls hit the function
|
||||
|
||||
def test_custom_namespace(self, mock_llm_func):
|
||||
"""Test custom namespace isolates cache."""
|
||||
@cache(namespace="test")
|
||||
def cached_func(prompt: str) -> str:
|
||||
return mock_llm_func(prompt)
|
||||
|
||||
result = cached_func("What is Python?")
|
||||
assert result == "Python is a programming language."
|
||||
|
||||
def test_ttl_expiration(self, mock_llm_func):
|
||||
"""Test TTL expiration works."""
|
||||
@cache(ttl=1) # 1 second TTL
|
||||
def cached_func(prompt: str) -> str:
|
||||
return mock_llm_func(prompt)
|
||||
|
||||
# First call
|
||||
result1 = cached_func("What is Python?")
|
||||
assert result1 == "Python is a programming language."
|
||||
|
||||
# Immediate second call - should hit cache
|
||||
result2 = cached_func("What is Python?")
|
||||
assert result2 == "Python is a programming language."
|
||||
|
||||
# Wait for expiration
|
||||
time.sleep(1.1)
|
||||
|
||||
# Third call - should miss due to TTL
|
||||
# Note: This test may be flaky in slow CI environments
|
||||
result3 = cached_func("What is Python?")
|
||||
assert result3 == "Python is a programming language."
|
||||
|
||||
def test_cache_with_exception(self):
|
||||
"""Test cache handles exceptions properly."""
|
||||
@cache()
|
||||
def failing_func(prompt: str) -> str:
|
||||
raise ValueError("LLM API error")
|
||||
|
||||
with pytest.raises(PromptCacheError):
|
||||
failing_func("test prompt")
|
||||
|
||||
def test_semantic_similarity_match(self, mock_llm_func):
|
||||
"""Test semantic similarity matching."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache(similarity=0.85)
|
||||
def cached_func(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return mock_llm_func(prompt)
|
||||
|
||||
# First call
|
||||
cached_func("What is Python?")
|
||||
assert call_count["count"] == 1
|
||||
|
||||
# Similar prompts may hit cache depending on embedding
|
||||
# Note: With dummy embeddings, exact string matching determines similarity
|
||||
cached_func("What is Python?") # Exact match
|
||||
assert call_count["count"] == 1
|
||||
|
||||
|
||||
class TestCacheContext:
|
||||
"""Tests for CacheContext manager."""
|
||||
|
||||
def test_context_manager(self):
|
||||
"""Test CacheContext works as context manager."""
|
||||
with CacheContext(similarity=0.9, ttl=1800) as ctx:
|
||||
assert ctx.config.similarity_threshold == 0.9
|
||||
assert ctx.config.ttl == 1800
|
||||
|
||||
def test_context_stats(self):
|
||||
"""Test context tracks stats."""
|
||||
with CacheContext() as ctx:
|
||||
stats = ctx.stats
|
||||
assert "hits" in stats
|
||||
assert "misses" in stats
|
||||
|
||||
|
||||
class TestCachedLLM:
|
||||
"""Tests for CachedLLM wrapper class."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test CachedLLM initialization."""
|
||||
llm = CachedLLM(provider="openai", model="gpt-4")
|
||||
assert llm._provider == "openai"
|
||||
assert llm._model == "gpt-4"
|
||||
|
||||
def test_chat_with_llm_func(self):
|
||||
"""Test chat method with custom LLM function."""
|
||||
llm = CachedLLM()
|
||||
|
||||
def mock_llm(prompt: str) -> str:
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
result = llm.chat("Hello", llm_func=mock_llm)
|
||||
assert result == "Response to: Hello"
|
||||
|
||||
def test_chat_caches_responses(self, mock_llm_func):
|
||||
"""Test chat caches responses."""
|
||||
llm = CachedLLM()
|
||||
call_count = {"count": 0}
|
||||
|
||||
def counting_llm(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return mock_llm_func(prompt)
|
||||
|
||||
llm.chat("What is Python?", llm_func=counting_llm)
|
||||
llm.chat("What is Python?", llm_func=counting_llm)
|
||||
|
||||
# Should cache (depends on embedding, may not with dummy)
|
||||
assert call_count["count"] >= 1
|
||||
|
||||
|
||||
class TestBackendManagement:
|
||||
"""Tests for backend management functions."""
|
||||
|
||||
def test_get_default_backend(self):
|
||||
"""Test get_default_backend returns a backend."""
|
||||
backend = get_default_backend()
|
||||
assert backend is not None
|
||||
assert isinstance(backend, MemoryBackend)
|
||||
|
||||
def test_set_default_backend(self):
|
||||
"""Test set_default_backend changes default."""
|
||||
custom_backend = MemoryBackend(max_size=10)
|
||||
set_default_backend(custom_backend)
|
||||
|
||||
backend = get_default_backend()
|
||||
assert backend is custom_backend
|
||||
|
||||
|
||||
class TestCacheEntry:
|
||||
"""Tests for CacheEntry class."""
|
||||
|
||||
def test_entry_creation(self):
|
||||
"""Test CacheEntry creation."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
created_at=time.time(),
|
||||
)
|
||||
assert entry.prompt == "test"
|
||||
assert entry.response == "response"
|
||||
|
||||
def test_is_expired_no_ttl(self):
|
||||
"""Test entry without TTL never expires."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
ttl=None,
|
||||
created_at=time.time() - 1000,
|
||||
)
|
||||
assert not entry.is_expired(time.time())
|
||||
|
||||
def test_is_expired_with_ttl(self):
|
||||
"""Test entry with TTL expires correctly."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
ttl=1, # 1 second
|
||||
created_at=time.time() - 2, # 2 seconds ago
|
||||
)
|
||||
assert entry.is_expired(time.time())
|
||||
|
||||
def test_estimate_cost(self):
|
||||
"""Test cost estimation."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
input_tokens=100,
|
||||
output_tokens=50,
|
||||
)
|
||||
cost = entry.estimate_cost(0.001, 0.002)
|
||||
# 100/1000 * 0.001 + 50/1000 * 0.002 = 0.0001 + 0.0001 = 0.0002
|
||||
assert abs(cost - 0.0002) < 1e-6
|
||||
|
||||
|
||||
class TestCacheConfig:
|
||||
"""Tests for CacheConfig class."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default configuration values."""
|
||||
config = CacheConfig()
|
||||
assert config.similarity_threshold == 1.0
|
||||
assert config.ttl == 3600
|
||||
assert config.namespace == "default"
|
||||
assert config.enabled is True
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test custom configuration values."""
|
||||
config = CacheConfig(
|
||||
similarity_threshold=0.9,
|
||||
ttl=7200,
|
||||
namespace="custom",
|
||||
enabled=False,
|
||||
)
|
||||
assert config.similarity_threshold == 0.9
|
||||
assert config.ttl == 7200
|
||||
assert config.namespace == "custom"
|
||||
assert config.enabled is False
|
||||
|
||||
def test_invalid_similarity(self):
|
||||
"""Test invalid similarity raises error."""
|
||||
with pytest.raises(ValueError, match="similarity_threshold"):
|
||||
CacheConfig(similarity_threshold=1.5)
|
||||
|
||||
def test_invalid_ttl(self):
|
||||
"""Test invalid TTL raises error."""
|
||||
with pytest.raises(ValueError, match="ttl"):
|
||||
CacheConfig(ttl=-1)
|
||||
|
||||
def test_invalid_max_size(self):
|
||||
"""Test invalid max_size raises error."""
|
||||
with pytest.raises(ValueError, match="max_cache_size"):
|
||||
CacheConfig(max_cache_size=0)
|
||||
|
||||
|
||||
class TestCacheDecoratorEdgeCases:
|
||||
"""Tests for edge cases in cache decorator."""
|
||||
|
||||
def test_cache_with_kwargs_only(self, mock_llm_func):
|
||||
"""Test caching when function is called with kwargs only."""
|
||||
@cache()
|
||||
def cached_func(prompt: str) -> str:
|
||||
return mock_llm_func(prompt)
|
||||
|
||||
result = cached_func(prompt="What is Python?")
|
||||
assert result == "Python is a programming language."
|
||||
|
||||
def test_cache_with_multiple_args(self):
|
||||
"""Test caching with multiple arguments."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache()
|
||||
def cached_func(prompt: str, temperature: float = 0.7) -> str:
|
||||
call_count["count"] += 1
|
||||
return f"Response to: {prompt} at {temperature}"
|
||||
|
||||
cached_func("test", 0.5)
|
||||
cached_func("test", 0.5)
|
||||
# Same prompt hits cache even with different temperature
|
||||
# (cache key is based on first arg)
|
||||
cached_func("test", 0.9)
|
||||
cached_func("different", 0.5)
|
||||
|
||||
# First call + different prompt = 2 calls
|
||||
assert call_count["count"] == 2
|
||||
|
||||
def test_cache_with_custom_key_func(self):
|
||||
"""Test custom key function."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
def custom_key(prompt: str, temperature: float = 0.7) -> str:
|
||||
return f"{prompt}:{temperature}"
|
||||
|
||||
@cache(key_func=custom_key)
|
||||
def cached_func(prompt: str, temperature: float = 0.7) -> str:
|
||||
call_count["count"] += 1
|
||||
return f"Response to: {prompt} at {temperature}"
|
||||
|
||||
cached_func("test", 0.7)
|
||||
cached_func("test", 0.7)
|
||||
assert call_count["count"] == 1
|
||||
|
||||
def test_semantic_match_threshold_edge(self, mock_llm_func):
|
||||
"""Test semantic matching at threshold boundaries."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache(similarity=0.5) # Lower threshold
|
||||
def cached_func(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return mock_llm_func(prompt)
|
||||
|
||||
# First call
|
||||
cached_func("What is Python?")
|
||||
# Similar query may hit with lower threshold
|
||||
cached_func("What is Python?") # Exact match always hits
|
||||
assert call_count["count"] == 1
|
||||
|
||||
def test_cache_with_none_response(self):
|
||||
"""Test caching None responses."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache()
|
||||
def cached_func(prompt: str) -> None:
|
||||
call_count["count"] += 1
|
||||
return None
|
||||
|
||||
result1 = cached_func("test")
|
||||
result2 = cached_func("test")
|
||||
|
||||
assert result1 is None
|
||||
assert result2 is None
|
||||
assert call_count["count"] == 1 # Should cache
|
||||
|
||||
def test_cache_with_empty_string_response(self):
|
||||
"""Test caching empty string responses."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache()
|
||||
def cached_func(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return ""
|
||||
|
||||
result1 = cached_func("test")
|
||||
result2 = cached_func("test")
|
||||
|
||||
assert result1 == ""
|
||||
assert result2 == ""
|
||||
assert call_count["count"] == 1 # Should cache
|
||||
|
||||
def test_cache_with_dict_response(self):
|
||||
"""Test caching dict responses."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache()
|
||||
def cached_func(prompt: str) -> dict:
|
||||
call_count["count"] += 1
|
||||
return {"key": "value", "number": 42}
|
||||
|
||||
result1 = cached_func("test")
|
||||
cached_func("test") # Second call to verify caching
|
||||
|
||||
assert result1 == {"key": "value", "number": 42}
|
||||
assert call_count["count"] == 1
|
||||
|
||||
def test_cache_with_list_response(self):
|
||||
"""Test caching list responses."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache()
|
||||
def cached_func(prompt: str) -> list:
|
||||
call_count["count"] += 1
|
||||
return [1, 2, 3, 4, 5]
|
||||
|
||||
result1 = cached_func("test")
|
||||
cached_func("test") # Second call to verify caching
|
||||
|
||||
assert result1 == [1, 2, 3, 4, 5]
|
||||
assert call_count["count"] == 1
|
||||
|
||||
|
||||
class TestCacheDecoratorErrorPaths:
|
||||
"""Tests for error handling in cache decorator."""
|
||||
|
||||
def test_backend_set_raises_error(self):
|
||||
"""Test that backend.set errors propagate."""
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
backend = MagicMock()
|
||||
# Backend set wraps exceptions in CacheBackendError
|
||||
backend.set.side_effect = CacheBackendError("Storage error")
|
||||
backend.get.return_value = None # No cached value
|
||||
|
||||
@cache(backend=backend)
|
||||
def cached_func(prompt: str) -> str:
|
||||
return "response"
|
||||
|
||||
# The CacheBackendError from backend.set should propagate
|
||||
with pytest.raises(CacheBackendError, match="Storage error"):
|
||||
cached_func("test")
|
||||
|
||||
def test_backend_get_raises_error(self):
|
||||
"""Test that backend.get errors propagate."""
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
backend = MagicMock()
|
||||
# Backend get wraps exceptions in CacheBackendError
|
||||
backend.get.side_effect = CacheBackendError("Get error")
|
||||
|
||||
@cache(backend=backend)
|
||||
def cached_func(prompt: str) -> str:
|
||||
return "response"
|
||||
|
||||
# The CacheBackendError from backend.get should propagate
|
||||
with pytest.raises(CacheBackendError, match="Get error"):
|
||||
cached_func("test")
|
||||
|
||||
def test_llm_error_still_wrapped(self):
|
||||
"""Test that LLM errors are still wrapped in PromptCacheError."""
|
||||
from semantic_llm_cache.exceptions import PromptCacheError
|
||||
|
||||
@cache()
|
||||
def failing_func(prompt: str) -> str:
|
||||
raise ValueError("LLM API error")
|
||||
|
||||
with pytest.raises(PromptCacheError):
|
||||
failing_func("test")
|
||||
|
||||
|
||||
class TestCacheContextAdvanced:
|
||||
"""Advanced tests for CacheContext."""
|
||||
|
||||
def test_context_with_zero_similarity(self):
|
||||
"""Test context with zero similarity (accept all)."""
|
||||
with CacheContext(similarity=0.0) as ctx:
|
||||
assert ctx.config.similarity_threshold == 0.0
|
||||
|
||||
def test_context_with_infinite_ttl(self):
|
||||
"""Test context with infinite TTL (None)."""
|
||||
with CacheContext(ttl=None) as ctx:
|
||||
assert ctx.config.ttl is None
|
||||
|
||||
def test_context_disabled(self):
|
||||
"""Test disabled context."""
|
||||
with CacheContext(enabled=False) as ctx:
|
||||
assert ctx.config.enabled is False
|
||||
|
||||
|
||||
class TestCachedLLMAdvanced:
|
||||
"""Advanced tests for CachedLLM."""
|
||||
|
||||
def test_chat_with_kwargs(self):
|
||||
"""Test chat with additional kwargs passed to LLM."""
|
||||
llm = CachedLLM()
|
||||
|
||||
def mock_llm(prompt: str, temperature: float = 0.7, max_tokens: int = 100) -> str:
|
||||
return f"Response to: {prompt} (temp={temperature}, tokens={max_tokens})"
|
||||
|
||||
result = llm.chat("Hello", llm_func=mock_llm, temperature=0.5, max_tokens=200)
|
||||
assert "temp=0.5" in result
|
||||
assert "tokens=200" in result
|
||||
|
||||
def test_cached_llm_with_custom_backend(self):
|
||||
"""Test CachedLLM with custom backend."""
|
||||
custom_backend = MemoryBackend(max_size=5)
|
||||
llm = CachedLLM(backend=custom_backend)
|
||||
|
||||
assert llm._backend is custom_backend
|
||||
|
||||
def test_cached_llm_different_namespaces(self):
|
||||
"""Test CachedLLM with different namespaces."""
|
||||
llm1 = CachedLLM(namespace="ns1")
|
||||
llm2 = CachedLLM(namespace="ns2")
|
||||
|
||||
assert llm1._config.namespace == "ns1"
|
||||
assert llm2._config.namespace == "ns2"
|
||||
|
||||
|
||||
class TestBackendManagementAdvanced:
|
||||
"""Advanced tests for backend management."""
|
||||
|
||||
def test_multiple_default_backend_changes(self):
|
||||
"""Test changing default backend multiple times."""
|
||||
backend1 = MemoryBackend(max_size=10)
|
||||
backend2 = MemoryBackend(max_size=20)
|
||||
backend3 = MemoryBackend(max_size=30)
|
||||
|
||||
set_default_backend(backend1)
|
||||
assert get_default_backend() is backend1
|
||||
|
||||
set_default_backend(backend2)
|
||||
assert get_default_backend() is backend2
|
||||
|
||||
set_default_backend(backend3)
|
||||
assert get_default_backend() is backend3
|
||||
|
||||
def test_backend_persists_stats(self):
|
||||
"""Test backend stats persist across get_default_backend calls."""
|
||||
backend = get_default_backend()
|
||||
|
||||
# Create an entry
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
created_at=time.time()
|
||||
)
|
||||
backend.set("key1", entry)
|
||||
|
||||
# Get stats
|
||||
stats = backend.get_stats()
|
||||
assert stats["size"] == 1
|
||||
|
||||
|
||||
class TestCacheEntryEdgeCases:
|
||||
"""Edge case tests for CacheEntry."""
|
||||
|
||||
def test_entry_with_zero_tokens(self):
|
||||
"""Test entry with zero token counts."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
)
|
||||
cost = entry.estimate_cost(0.001, 0.002)
|
||||
assert cost == 0.0
|
||||
|
||||
def test_entry_with_large_token_count(self):
|
||||
"""Test entry with large token counts."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
input_tokens=100000,
|
||||
output_tokens=50000,
|
||||
)
|
||||
cost = entry.estimate_cost(0.001, 0.002)
|
||||
assert cost > 0
|
||||
|
||||
def test_entry_with_negative_ttl(self):
|
||||
"""Test entry creation handles negative TTL (becomes expired)."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
ttl=-1,
|
||||
created_at=time.time(),
|
||||
)
|
||||
# Negative TTL means immediately expired
|
||||
assert entry.is_expired(time.time())
|
||||
|
||||
def test_entry_hit_count_initialization(self):
|
||||
"""Test entry initializes with zero hit count."""
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
created_at=time.time(),
|
||||
)
|
||||
assert entry.hit_count == 0
|
||||
|
||||
|
||||
class TestCacheConfigEdgeCases:
|
||||
"""Edge case tests for CacheConfig."""
|
||||
|
||||
def test_config_boundary_similarity(self):
|
||||
"""Test similarity at valid boundaries."""
|
||||
config1 = CacheConfig(similarity_threshold=0.0)
|
||||
assert config1.similarity_threshold == 0.0
|
||||
|
||||
config2 = CacheConfig(similarity_threshold=1.0)
|
||||
assert config2.similarity_threshold == 1.0
|
||||
|
||||
def test_config_zero_ttl(self):
|
||||
"""Test zero TTL is rejected (validation requires positive)."""
|
||||
# The validation in CacheConfig rejects ttl <= 0
|
||||
with pytest.raises(ValueError, match="ttl"):
|
||||
CacheConfig(ttl=0)
|
||||
|
||||
def test_config_very_large_ttl(self):
|
||||
"""Test very large TTL."""
|
||||
config = CacheConfig(ttl=86400 * 365) # 1 year
|
||||
assert config.ttl == 86400 * 365
|
||||
|
||||
def test_config_with_special_namespace(self):
|
||||
"""Test namespace with special characters."""
|
||||
config = CacheConfig(namespace="test-ns_123.v1")
|
||||
assert config.namespace == "test-ns_123.v1"
|
||||
330
tests/test_integration.py
Normal file
330
tests/test_integration.py
Normal file
|
|
@ -0,0 +1,330 @@
|
|||
"""Integration tests for prompt-cache."""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from semantic_llm_cache import cache, clear_cache, get_stats, invalidate
|
||||
from semantic_llm_cache.backends import MemoryBackend
|
||||
|
||||
|
||||
class TestEndToEnd:
|
||||
"""End-to-end integration tests."""
|
||||
|
||||
def test_full_cache_workflow(self):
|
||||
"""Test complete cache workflow from hit to miss."""
|
||||
backend = MemoryBackend()
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache(backend=backend)
|
||||
def llm_function(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
# First call - miss
|
||||
result1 = llm_function("What is Python?")
|
||||
assert result1 == "Response to: What is Python?"
|
||||
assert call_count["count"] == 1
|
||||
|
||||
# Second call - hit
|
||||
result2 = llm_function("What is Python?")
|
||||
assert result2 == "Response to: What is Python?"
|
||||
assert call_count["count"] == 1
|
||||
|
||||
# Different prompt - miss
|
||||
result3 = llm_function("What is Rust?")
|
||||
assert result3 == "Response to: What is Rust?"
|
||||
assert call_count["count"] == 2
|
||||
|
||||
def test_stats_integration(self):
|
||||
"""Test statistics tracking."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend, namespace="test")
|
||||
def llm_function(prompt: str) -> str:
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
# Generate some activity
|
||||
llm_function("prompt 1")
|
||||
llm_function("prompt 1") # Hit
|
||||
llm_function("prompt 2")
|
||||
|
||||
stats = get_stats(namespace="test")
|
||||
assert stats["total_requests"] >= 2
|
||||
|
||||
def test_clear_cache_integration(self):
|
||||
"""Test clearing cache affects function behavior."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend)
|
||||
def llm_function(prompt: str) -> str:
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
llm_function("test prompt")
|
||||
|
||||
# Clear cache
|
||||
cleared = clear_cache()
|
||||
assert cleared >= 0
|
||||
|
||||
# Function should still work
|
||||
result = llm_function("test prompt")
|
||||
assert result == "Response to: test prompt"
|
||||
|
||||
def test_invalidate_integration(self):
|
||||
"""Test invalidating cache entries."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend)
|
||||
def llm_function(prompt: str) -> str:
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
llm_function("Python programming")
|
||||
llm_function("Rust programming")
|
||||
|
||||
# Invalidate Python entries
|
||||
count = invalidate("Python")
|
||||
assert count >= 0
|
||||
|
||||
def test_multiple_namespaces(self):
|
||||
"""Test cache isolation across namespaces."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend, namespace="app1")
|
||||
def app1_llm(prompt: str) -> str:
|
||||
return f"App1: {prompt}"
|
||||
|
||||
@cache(backend=backend, namespace="app2")
|
||||
def app2_llm(prompt: str) -> str:
|
||||
return f"App2: {prompt}"
|
||||
|
||||
result1 = app1_llm("test")
|
||||
result2 = app2_llm("test")
|
||||
|
||||
assert result1 == "App1: test"
|
||||
assert result2 == "App2: test"
|
||||
|
||||
def test_ttl_expiration_integration(self):
|
||||
"""Test TTL expiration in real workflow."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend, ttl=1) # 1 second TTL
|
||||
def llm_function(prompt: str) -> str:
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
llm_function("test prompt")
|
||||
|
||||
# Immediate second call - hit
|
||||
llm_function("test prompt")
|
||||
|
||||
# Wait for expiration
|
||||
time.sleep(1.5)
|
||||
|
||||
# Should miss (cached entry expired)
|
||||
llm_function("test prompt")
|
||||
|
||||
|
||||
class TestComplexScenarios:
|
||||
"""Tests for complex real-world scenarios."""
|
||||
|
||||
def test_high_volume_caching(self):
|
||||
"""Test cache behavior with many entries."""
|
||||
backend = MemoryBackend(max_size=100)
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache(backend=backend)
|
||||
def llm_function(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return f"Response {call_count['count']}"
|
||||
|
||||
# Add many entries
|
||||
for i in range(150):
|
||||
llm_function(f"prompt {i}")
|
||||
|
||||
# Some entries should have been evicted
|
||||
stats = backend.get_stats()
|
||||
assert stats["size"] <= 100
|
||||
|
||||
def test_concurrent_like_access(self):
|
||||
"""Test multiple calls to same cached entry."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend)
|
||||
def llm_function(prompt: str) -> str:
|
||||
return f"Unique: {time.time()}"
|
||||
|
||||
# Multiple calls
|
||||
results = [llm_function("test") for _ in range(5)]
|
||||
|
||||
# All should return same result (cached)
|
||||
assert len(set(results)) == 1
|
||||
|
||||
def test_different_return_types(self):
|
||||
"""Test caching different return types."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend)
|
||||
def return_dict(prompt: str) -> dict:
|
||||
return {"key": "value"}
|
||||
|
||||
@cache(backend=backend)
|
||||
def return_list(prompt: str) -> list:
|
||||
return [1, 2, 3]
|
||||
|
||||
@cache(backend=backend)
|
||||
def return_string(prompt: str) -> str:
|
||||
return "string response"
|
||||
|
||||
# Use unique prompts to avoid cache collision
|
||||
assert isinstance(return_dict("test_dict"), dict)
|
||||
assert isinstance(return_list("test_list"), list)
|
||||
assert isinstance(return_string("test_string"), str)
|
||||
|
||||
def test_empty_and_none_responses(self):
|
||||
"""Test caching empty and None responses."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend)
|
||||
def return_empty(prompt: str) -> str:
|
||||
return ""
|
||||
|
||||
@cache(backend=backend)
|
||||
def return_none(prompt: str) -> None:
|
||||
return None
|
||||
|
||||
assert return_empty("empty_test") == ""
|
||||
assert return_none("none_test") is None
|
||||
|
||||
# Should still cache (second calls should hit cache)
|
||||
assert return_empty("empty_test") == ""
|
||||
assert return_none("none_test") is None
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling in various scenarios."""
|
||||
|
||||
def test_function_with_exception(self):
|
||||
"""Test function that raises exception."""
|
||||
from semantic_llm_cache.exceptions import PromptCacheError
|
||||
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend)
|
||||
def failing_function(prompt: str) -> str:
|
||||
if "error" in prompt:
|
||||
raise ValueError("Test error")
|
||||
return "OK"
|
||||
|
||||
# Normal call works
|
||||
assert failing_function("normal") == "OK"
|
||||
|
||||
# Error call raises PromptCacheError (wrapped exception)
|
||||
with pytest.raises(PromptCacheError):
|
||||
failing_function("error prompt")
|
||||
|
||||
# Normal call still works
|
||||
assert failing_function("normal") == "OK"
|
||||
|
||||
def test_backend_error_handling(self):
|
||||
"""Test that backend wraps errors properly."""
|
||||
from semantic_llm_cache.backends.memory import MemoryBackend
|
||||
|
||||
# Use MemoryBackend which has proper error handling
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend)
|
||||
def working_func(prompt: str) -> str:
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
# Normal operation works
|
||||
assert working_func("test") == "Response to: test"
|
||||
|
||||
# Second call hits cache
|
||||
assert working_func("test") == "Response to: test"
|
||||
|
||||
# Backend properly stores and retrieves entries
|
||||
stats = backend.get_stats()
|
||||
assert stats["hits"] >= 1
|
||||
|
||||
|
||||
class TestPromptNormalization:
|
||||
"""Tests for prompt normalization effects."""
|
||||
|
||||
def test_whitespace_normalization(self):
|
||||
"""Test prompts with different whitespace are cached separately."""
|
||||
backend = MemoryBackend()
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache(backend=backend)
|
||||
def llm_function(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return f"Response: {prompt}"
|
||||
|
||||
llm_function("What is Python?")
|
||||
llm_function("What is Python?") # Extra spaces
|
||||
|
||||
# Normalization should make these the same
|
||||
# Note: This depends on the normalization implementation
|
||||
assert call_count["count"] >= 1
|
||||
|
||||
def test_case_sensitivity(self):
|
||||
"""Test case sensitivity in caching."""
|
||||
backend = MemoryBackend()
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache(backend=backend)
|
||||
def llm_function(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return f"Response: {prompt}"
|
||||
|
||||
llm_function("What is Python?")
|
||||
llm_function("what is python?")
|
||||
|
||||
# Case differences create different cache entries
|
||||
# (normalization doesn't lowercase by default)
|
||||
assert call_count["count"] >= 1
|
||||
|
||||
|
||||
class TestConfigurationCombinations:
|
||||
"""Tests for various configuration combinations."""
|
||||
|
||||
def test_no_caching_config(self):
|
||||
"""Test configuration with caching disabled."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend, enabled=False)
|
||||
def llm_function(prompt: str) -> str:
|
||||
return f"Response: {time.time()}"
|
||||
|
||||
result1 = llm_function("test")
|
||||
time.sleep(0.01)
|
||||
result2 = llm_function("test")
|
||||
|
||||
# Without caching, results differ
|
||||
assert result1 != result2
|
||||
|
||||
def test_zero_ttl(self):
|
||||
"""Test zero TTL means immediate expiration."""
|
||||
backend = MemoryBackend()
|
||||
|
||||
@cache(backend=backend, ttl=0)
|
||||
def llm_function(prompt: str) -> str:
|
||||
return f"Response: {prompt}"
|
||||
|
||||
llm_function("test")
|
||||
# Entry immediately expires, so next call is a miss
|
||||
llm_function("test")
|
||||
|
||||
def test_infinite_ttl(self):
|
||||
"""Test None TTL means never expire."""
|
||||
backend = MemoryBackend()
|
||||
call_count = {"count": 0}
|
||||
|
||||
@cache(backend=backend, ttl=None)
|
||||
def llm_function(prompt: str) -> str:
|
||||
call_count["count"] += 1
|
||||
return f"Response: {prompt}"
|
||||
|
||||
llm_function("test")
|
||||
llm_function("test")
|
||||
|
||||
assert call_count["count"] == 1
|
||||
208
tests/test_similarity.py
Normal file
208
tests/test_similarity.py
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
"""Tests for embedding generation and similarity matching."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from semantic_llm_cache.similarity import (
|
||||
DummyEmbeddingProvider,
|
||||
EmbeddingCache,
|
||||
OpenAIEmbeddingProvider,
|
||||
SentenceTransformerProvider,
|
||||
cosine_similarity,
|
||||
create_embedding_provider,
|
||||
)
|
||||
|
||||
|
||||
class TestCosineSimilarity:
|
||||
"""Tests for cosine_similarity function."""
|
||||
|
||||
def test_identical_vectors(self):
|
||||
"""Test identical vectors have similarity 1.0."""
|
||||
a = [1.0, 2.0, 3.0]
|
||||
b = [1.0, 2.0, 3.0]
|
||||
assert cosine_similarity(a, b) == pytest.approx(1.0)
|
||||
|
||||
def test_orthogonal_vectors(self):
|
||||
"""Test orthogonal vectors have similarity 0.0."""
|
||||
a = [1.0, 0.0, 0.0]
|
||||
b = [0.0, 1.0, 0.0]
|
||||
assert cosine_similarity(a, b) == pytest.approx(0.0)
|
||||
|
||||
def test_opposite_vectors(self):
|
||||
"""Test opposite vectors have similarity -1.0."""
|
||||
a = [1.0, 2.0, 3.0]
|
||||
b = [-1.0, -2.0, -3.0]
|
||||
assert cosine_similarity(a, b) == pytest.approx(-1.0)
|
||||
|
||||
def test_zero_vectors(self):
|
||||
"""Test zero vectors return 0.0."""
|
||||
a = [0.0, 0.0, 0.0]
|
||||
b = [1.0, 2.0, 3.0]
|
||||
assert cosine_similarity(a, b) == 0.0
|
||||
|
||||
def test_numpy_array_input(self):
|
||||
"""Test function accepts numpy arrays."""
|
||||
a = np.array([1.0, 2.0, 3.0])
|
||||
b = np.array([1.0, 2.0, 3.0])
|
||||
assert cosine_similarity(a, b) == pytest.approx(1.0)
|
||||
|
||||
def test_mixed_dimensions(self):
|
||||
"""Test vectors of different dimensions raise error."""
|
||||
a = [1.0, 2.0]
|
||||
b = [1.0, 2.0, 3.0]
|
||||
# Should raise ValueError for mismatched dimensions
|
||||
with pytest.raises(ValueError, match="dimension mismatch"):
|
||||
cosine_similarity(a, b)
|
||||
|
||||
|
||||
class TestDummyEmbeddingProvider:
|
||||
"""Tests for DummyEmbeddingProvider."""
|
||||
|
||||
def test_encode_returns_list(self):
|
||||
"""Test encode returns list of floats."""
|
||||
provider = DummyEmbeddingProvider()
|
||||
embedding = provider.encode("test prompt")
|
||||
assert isinstance(embedding, list)
|
||||
assert all(isinstance(x, float) for x in embedding)
|
||||
|
||||
def test_encode_deterministic(self):
|
||||
"""Test same input produces same output."""
|
||||
provider = DummyEmbeddingProvider()
|
||||
text = "test prompt"
|
||||
e1 = provider.encode(text)
|
||||
e2 = provider.encode(text)
|
||||
assert e1 == e2
|
||||
|
||||
def test_encode_different_inputs(self):
|
||||
"""Test different inputs produce different outputs."""
|
||||
provider = DummyEmbeddingProvider()
|
||||
e1 = provider.encode("prompt 1")
|
||||
e2 = provider.encode("prompt 2")
|
||||
assert e1 != e2
|
||||
|
||||
def test_custom_dimension(self):
|
||||
"""Test custom embedding dimension."""
|
||||
provider = DummyEmbeddingProvider(dim=128)
|
||||
embedding = provider.encode("test")
|
||||
assert len(embedding) == 128
|
||||
|
||||
def test_embedding_normalized(self):
|
||||
"""Test embeddings are normalized to unit length."""
|
||||
provider = DummyEmbeddingProvider()
|
||||
embedding = provider.encode("test prompt")
|
||||
# Calculate norm
|
||||
norm = np.linalg.norm(embedding)
|
||||
assert norm == pytest.approx(1.0, rel=1e-5)
|
||||
|
||||
|
||||
class TestSentenceTransformerProvider:
|
||||
"""Tests for SentenceTransformerProvider."""
|
||||
|
||||
@pytest.mark.skip(reason="Requires sentence-transformers installation")
|
||||
def test_encode_returns_list(self):
|
||||
"""Test encode returns list of floats."""
|
||||
provider = SentenceTransformerProvider()
|
||||
embedding = provider.encode("test prompt")
|
||||
assert isinstance(embedding, list)
|
||||
assert all(isinstance(x, float) for x in embedding)
|
||||
|
||||
@pytest.mark.skip(reason="Requires sentence-transformers installation")
|
||||
def test_encode_deterministic(self):
|
||||
"""Test same input produces same output."""
|
||||
provider = SentenceTransformerProvider()
|
||||
text = "test prompt"
|
||||
e1 = provider.encode(text)
|
||||
e2 = provider.encode(text)
|
||||
assert e1 == e2
|
||||
|
||||
def test_import_error_without_package(self, monkeypatch):
|
||||
"""Test ImportError raised when package not installed."""
|
||||
# Skip if sentence-transformers is installed
|
||||
pytest.importorskip("sentence_transformers", reason="sentence-transformers is installed, cannot test import error")
|
||||
|
||||
|
||||
class TestOpenAIEmbeddingProvider:
|
||||
"""Tests for OpenAIEmbeddingProvider."""
|
||||
|
||||
@pytest.mark.skip(reason="Requires OpenAI API key")
|
||||
def test_encode_returns_list(self):
|
||||
"""Test encode returns list of floats."""
|
||||
provider = OpenAIEmbeddingProvider()
|
||||
embedding = provider.encode("test prompt")
|
||||
assert isinstance(embedding, list)
|
||||
assert all(isinstance(x, float) for x in embedding)
|
||||
|
||||
@pytest.mark.skip(reason="Requires OpenAI API key")
|
||||
def test_encode_deterministic(self):
|
||||
"""Test same input produces same output."""
|
||||
provider = OpenAIEmbeddingProvider()
|
||||
text = "test prompt"
|
||||
e1 = provider.encode(text)
|
||||
e2 = provider.encode(text)
|
||||
# OpenAI embeddings may vary slightly
|
||||
assert len(e1) == len(e2)
|
||||
|
||||
def test_import_error_without_package(self, monkeypatch):
|
||||
"""Test ImportError raised when package not installed."""
|
||||
# Skip if openai is installed
|
||||
pytest.importorskip("openai", reason="openai is installed, cannot test import error")
|
||||
|
||||
|
||||
class TestEmbeddingCache:
|
||||
"""Tests for EmbeddingCache."""
|
||||
|
||||
def test_cache_provider(self):
|
||||
"""Test cache uses provider."""
|
||||
provider = DummyEmbeddingProvider()
|
||||
cache = EmbeddingCache(provider=provider)
|
||||
|
||||
# Encode same text twice
|
||||
e1 = cache.encode("test prompt")
|
||||
e2 = cache.encode("test prompt")
|
||||
|
||||
assert e1 == e2
|
||||
|
||||
def test_cache_clear(self):
|
||||
"""Test cache can be cleared."""
|
||||
provider = DummyEmbeddingProvider()
|
||||
cache = EmbeddingCache(provider=provider)
|
||||
|
||||
cache.encode("test prompt")
|
||||
cache.clear_cache()
|
||||
|
||||
# Should still work after clear
|
||||
embedding = cache.encode("test prompt")
|
||||
assert len(embedding) > 0
|
||||
|
||||
def test_cache_default_provider(self):
|
||||
"""Test cache uses dummy provider by default."""
|
||||
cache = EmbeddingCache()
|
||||
embedding = cache.encode("test prompt")
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
class TestCreateEmbeddingProvider:
|
||||
"""Tests for create_embedding_provider factory."""
|
||||
|
||||
def test_create_dummy_provider(self):
|
||||
"""Test creating dummy provider."""
|
||||
provider = create_embedding_provider("dummy")
|
||||
assert isinstance(provider, DummyEmbeddingProvider)
|
||||
|
||||
def test_create_auto_provider(self):
|
||||
"""Test auto provider creates sentence-transformers when available."""
|
||||
provider = create_embedding_provider("auto")
|
||||
# Creates SentenceTransformerProvider if available, else DummyEmbeddingProvider
|
||||
assert isinstance(provider, (DummyEmbeddingProvider, SentenceTransformerProvider))
|
||||
|
||||
def test_invalid_provider_type(self):
|
||||
"""Test invalid provider type raises error."""
|
||||
with pytest.raises(ValueError, match="Unknown provider type"):
|
||||
create_embedding_provider("invalid_type")
|
||||
|
||||
def test_custom_model_name(self, monkeypatch):
|
||||
"""Test custom model name is passed through."""
|
||||
# This would work with actual sentence-transformers
|
||||
provider = create_embedding_provider("dummy", model_name="custom-model")
|
||||
assert isinstance(provider, DummyEmbeddingProvider)
|
||||
443
tests/test_stats.py
Normal file
443
tests/test_stats.py
Normal file
|
|
@ -0,0 +1,443 @@
|
|||
"""Tests for statistics and analytics module."""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from semantic_llm_cache.backends import MemoryBackend
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.stats import (
|
||||
CacheStats,
|
||||
_stats_manager,
|
||||
clear_cache,
|
||||
export_cache,
|
||||
get_stats,
|
||||
invalidate,
|
||||
warm_cache,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_stats_state():
|
||||
"""Clear stats state before each test."""
|
||||
_stats_manager.clear_stats()
|
||||
yield
|
||||
|
||||
|
||||
class TestCacheStats:
|
||||
"""Tests for CacheStats dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test CacheStats initializes with defaults."""
|
||||
stats = CacheStats()
|
||||
assert stats.hits == 0
|
||||
assert stats.misses == 0
|
||||
assert stats.total_saved_ms == 0.0
|
||||
assert stats.estimated_savings_usd == 0.0
|
||||
|
||||
def test_hit_rate_empty(self):
|
||||
"""Test hit rate with no requests."""
|
||||
stats = CacheStats()
|
||||
assert stats.hit_rate == 0.0
|
||||
|
||||
def test_hit_rate_all_hits(self):
|
||||
"""Test hit rate with all cache hits."""
|
||||
stats = CacheStats(hits=10, misses=0)
|
||||
assert stats.hit_rate == 1.0
|
||||
|
||||
def test_hit_rate_all_misses(self):
|
||||
"""Test hit rate with all cache misses."""
|
||||
stats = CacheStats(hits=0, misses=10)
|
||||
assert stats.hit_rate == 0.0
|
||||
|
||||
def test_hit_rate_mixed(self):
|
||||
"""Test hit rate with mixed hits and misses."""
|
||||
stats = CacheStats(hits=7, misses=3)
|
||||
assert stats.hit_rate == 0.7
|
||||
|
||||
def test_total_requests(self):
|
||||
"""Test total requests calculation."""
|
||||
stats = CacheStats(hits=5, misses=3)
|
||||
assert stats.total_requests == 8
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting stats to dictionary."""
|
||||
stats = CacheStats(hits=10, misses=5, total_saved_ms=1000.0, estimated_savings_usd=0.5)
|
||||
result = stats.to_dict()
|
||||
|
||||
assert result["hits"] == 10
|
||||
assert result["misses"] == 5
|
||||
assert result["hit_rate"] == 2/3
|
||||
assert result["total_requests"] == 15
|
||||
assert result["total_saved_ms"] == 1000.0
|
||||
assert result["estimated_savings_usd"] == 0.5
|
||||
|
||||
def test_iadd(self):
|
||||
"""Test in-place addition of stats."""
|
||||
stats1 = CacheStats(hits=5, misses=3, total_saved_ms=500.0, estimated_savings_usd=0.25)
|
||||
stats2 = CacheStats(hits=3, misses=2, total_saved_ms=300.0, estimated_savings_usd=0.15)
|
||||
|
||||
stats1 += stats2
|
||||
|
||||
assert stats1.hits == 8
|
||||
assert stats1.misses == 5
|
||||
assert stats1.total_saved_ms == 800.0
|
||||
assert stats1.estimated_savings_usd == 0.4
|
||||
|
||||
|
||||
class TestStatsManager:
|
||||
"""Tests for _StatsManager."""
|
||||
|
||||
def test_record_hit(self):
|
||||
"""Test recording a cache hit."""
|
||||
_stats_manager.record_hit("test_ns", latency_saved_ms=100.0, saved_cost=0.01)
|
||||
|
||||
stats = _stats_manager.get_stats("test_ns")
|
||||
assert stats.hits == 1
|
||||
assert stats.total_saved_ms == 100.0
|
||||
assert stats.estimated_savings_usd == 0.01
|
||||
|
||||
def test_record_multiple_hits(self):
|
||||
"""Test recording multiple hits."""
|
||||
_stats_manager.record_hit("test_ns", latency_saved_ms=50.0, saved_cost=0.005)
|
||||
_stats_manager.record_hit("test_ns", latency_saved_ms=75.0, saved_cost=0.008)
|
||||
|
||||
stats = _stats_manager.get_stats("test_ns")
|
||||
assert stats.hits == 2
|
||||
assert stats.total_saved_ms == 125.0
|
||||
|
||||
def test_record_miss(self):
|
||||
"""Test recording a cache miss."""
|
||||
_stats_manager.record_miss("test_ns")
|
||||
|
||||
stats = _stats_manager.get_stats("test_ns")
|
||||
assert stats.misses == 1
|
||||
|
||||
def test_get_stats_namespace(self):
|
||||
"""Test getting stats for specific namespace."""
|
||||
_stats_manager.record_hit("ns1", latency_saved_ms=100.0)
|
||||
_stats_manager.record_miss("ns1")
|
||||
_stats_manager.record_hit("ns2", latency_saved_ms=50.0)
|
||||
|
||||
stats1 = _stats_manager.get_stats("ns1")
|
||||
stats2 = _stats_manager.get_stats("ns2")
|
||||
|
||||
assert stats1.hits == 1
|
||||
assert stats1.misses == 1
|
||||
assert stats2.hits == 1
|
||||
assert stats2.misses == 0
|
||||
|
||||
def test_get_stats_all_namespaces(self):
|
||||
"""Test getting aggregated stats for all namespaces."""
|
||||
_stats_manager.record_hit("ns1", latency_saved_ms=100.0)
|
||||
_stats_manager.record_hit("ns2", latency_saved_ms=50.0)
|
||||
_stats_manager.record_miss("ns1")
|
||||
|
||||
stats = _stats_manager.get_stats(None) # All namespaces
|
||||
assert stats.hits == 2
|
||||
assert stats.misses == 1
|
||||
|
||||
def test_get_stats_nonexistent_namespace(self):
|
||||
"""Test getting stats for namespace with no activity."""
|
||||
stats = _stats_manager.get_stats("nonexistent")
|
||||
assert stats.hits == 0
|
||||
assert stats.misses == 0
|
||||
|
||||
def test_clear_stats_namespace(self):
|
||||
"""Test clearing stats for specific namespace."""
|
||||
_stats_manager.record_hit("ns1", latency_saved_ms=100.0)
|
||||
_stats_manager.record_hit("ns2", latency_saved_ms=50.0)
|
||||
|
||||
_stats_manager.clear_stats("ns1")
|
||||
|
||||
stats1 = _stats_manager.get_stats("ns1")
|
||||
stats2 = _stats_manager.get_stats("ns2")
|
||||
|
||||
assert stats1.hits == 0
|
||||
assert stats2.hits == 1
|
||||
|
||||
def test_clear_stats_all(self):
|
||||
"""Test clearing all stats."""
|
||||
_stats_manager.record_hit("ns1", latency_saved_ms=100.0)
|
||||
_stats_manager.record_hit("ns2", latency_saved_ms=50.0)
|
||||
|
||||
_stats_manager.clear_stats()
|
||||
|
||||
stats = _stats_manager.get_stats(None)
|
||||
assert stats.hits == 0
|
||||
assert stats.misses == 0
|
||||
|
||||
def test_set_backend(self):
|
||||
"""Test setting default backend."""
|
||||
custom_backend = MemoryBackend(max_size=10)
|
||||
_stats_manager.set_backend(custom_backend)
|
||||
|
||||
retrieved = _stats_manager.get_backend()
|
||||
assert retrieved is custom_backend
|
||||
|
||||
|
||||
class TestPublicStatsAPI:
|
||||
"""Tests for public stats API functions."""
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test get_stats returns dictionary."""
|
||||
_stats_manager.record_hit("test_ns", latency_saved_ms=100.0)
|
||||
|
||||
stats = get_stats("test_ns")
|
||||
assert isinstance(stats, dict)
|
||||
assert "hits" in stats
|
||||
assert "misses" in stats
|
||||
assert "hit_rate" in stats
|
||||
|
||||
def test_clear_cache_all(self):
|
||||
"""Test clear_cache clears all entries."""
|
||||
backend = _stats_manager.get_backend()
|
||||
|
||||
# Add some entries
|
||||
entry = CacheEntry(prompt="test", response="response", created_at=time.time())
|
||||
backend.set("key1", entry)
|
||||
backend.set("key2", entry)
|
||||
|
||||
count = clear_cache()
|
||||
assert count >= 0
|
||||
|
||||
def test_clear_cache_namespace(self):
|
||||
"""Test clear_cache clears specific namespace."""
|
||||
backend = _stats_manager.get_backend()
|
||||
|
||||
# Add entries in different namespaces
|
||||
entry1 = CacheEntry(prompt="test1", response="r1", namespace="ns1", created_at=time.time())
|
||||
entry2 = CacheEntry(prompt="test2", response="r2", namespace="ns2", created_at=time.time())
|
||||
|
||||
backend.set("key1", entry1)
|
||||
backend.set("key2", entry2)
|
||||
|
||||
count = clear_cache(namespace="ns1")
|
||||
assert count >= 0
|
||||
|
||||
def test_invalidate_pattern(self):
|
||||
"""Test invalidating entries by pattern."""
|
||||
backend = _stats_manager.get_backend()
|
||||
|
||||
# Add entries with different prompts
|
||||
entry1 = CacheEntry(prompt="Python programming", response="r1", created_at=time.time())
|
||||
entry2 = CacheEntry(prompt="Rust programming", response="r2", created_at=time.time())
|
||||
entry3 = CacheEntry(prompt="JavaScript", response="r3", created_at=time.time())
|
||||
|
||||
backend.set("key1", entry1)
|
||||
backend.set("key2", entry2)
|
||||
backend.set("key3", entry3)
|
||||
|
||||
count = invalidate("Python")
|
||||
assert count >= 0
|
||||
|
||||
def test_invalidate_case_insensitive(self):
|
||||
"""Test invalidate is case insensitive."""
|
||||
backend = _stats_manager.get_backend()
|
||||
|
||||
entry = CacheEntry(prompt="PYTHON programming", response="r1", created_at=time.time())
|
||||
backend.set("key1", entry)
|
||||
|
||||
count = invalidate("python")
|
||||
assert count >= 0
|
||||
|
||||
def test_invalidate_no_matches(self):
|
||||
"""Test invalidate with no matches."""
|
||||
backend = _stats_manager.get_backend()
|
||||
|
||||
entry = CacheEntry(prompt="Rust programming", response="r1", created_at=time.time())
|
||||
backend.set("key1", entry)
|
||||
|
||||
count = invalidate("Python")
|
||||
assert count == 0
|
||||
|
||||
def test_warm_cache(self):
|
||||
"""Test warming cache with prompts."""
|
||||
prompts = ["prompt1", "prompt2"]
|
||||
|
||||
def mock_llm(prompt: str) -> str:
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
count = warm_cache(prompts, mock_llm, namespace="warm_test")
|
||||
assert count == len(prompts)
|
||||
|
||||
def test_warm_cache_with_failures(self):
|
||||
"""Test warm_cache handles LLM failures gracefully."""
|
||||
|
||||
def failing_llm(prompt: str) -> str:
|
||||
if "fail" in prompt:
|
||||
raise ValueError("LLM error")
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
prompts = ["prompt1", "fail_prompt", "prompt3"]
|
||||
count = warm_cache(prompts, failing_llm, namespace="warm_fail_test")
|
||||
# Should return count even if some prompts fail
|
||||
assert count == len(prompts)
|
||||
|
||||
|
||||
class TestExportCache:
|
||||
"""Tests for export_cache function."""
|
||||
|
||||
def test_export_all_entries(self, tmp_path):
|
||||
"""Test exporting all cache entries."""
|
||||
backend = _stats_manager.get_backend()
|
||||
backend.clear()
|
||||
|
||||
# Add test entries
|
||||
entry1 = CacheEntry(
|
||||
prompt="test prompt 1",
|
||||
response="response 1",
|
||||
namespace="test_ns",
|
||||
created_at=time.time(),
|
||||
hit_count=5,
|
||||
ttl=3600,
|
||||
input_tokens=100,
|
||||
output_tokens=50,
|
||||
)
|
||||
backend.set("key1", entry1)
|
||||
|
||||
entries = export_cache()
|
||||
assert len(entries) >= 0
|
||||
|
||||
if entries:
|
||||
assert "key" in entries[0]
|
||||
assert "prompt" in entries[0]
|
||||
assert "response" in entries[0]
|
||||
assert "namespace" in entries[0]
|
||||
assert "hit_count" in entries[0]
|
||||
|
||||
def test_export_namespace_filtered(self, tmp_path):
|
||||
"""Test exporting entries filtered by namespace."""
|
||||
backend = _stats_manager.get_backend()
|
||||
backend.clear()
|
||||
|
||||
# Add entries in different namespaces
|
||||
entry1 = CacheEntry(
|
||||
prompt="test1", response="r1", namespace="ns1", created_at=time.time()
|
||||
)
|
||||
entry2 = CacheEntry(
|
||||
prompt="test2", response="r2", namespace="ns2", created_at=time.time()
|
||||
)
|
||||
|
||||
backend.set("key1", entry1)
|
||||
backend.set("key2", entry2)
|
||||
|
||||
entries = export_cache(namespace="ns1")
|
||||
# Should only return entries from ns1
|
||||
assert all(e["namespace"] == "ns1" for e in entries)
|
||||
|
||||
def test_export_to_file(self, tmp_path):
|
||||
"""Test exporting cache to JSON file."""
|
||||
import json
|
||||
|
||||
filepath = tmp_path / "export.json"
|
||||
|
||||
backend = _stats_manager.get_backend()
|
||||
backend.clear()
|
||||
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response="response",
|
||||
namespace="test",
|
||||
created_at=time.time(),
|
||||
hit_count=3,
|
||||
)
|
||||
backend.set("key1", entry)
|
||||
|
||||
export_cache(filepath=str(filepath))
|
||||
|
||||
# Verify file was created and is valid JSON
|
||||
assert filepath.exists()
|
||||
with open(filepath) as f:
|
||||
data = json.load(f)
|
||||
assert isinstance(data, list)
|
||||
|
||||
def test_export_truncates_large_responses(self, tmp_path):
|
||||
"""Test that large responses are truncated in export."""
|
||||
backend = _stats_manager.get_backend()
|
||||
backend.clear()
|
||||
|
||||
# Create entry with very large response
|
||||
large_response = "x" * 2000
|
||||
entry = CacheEntry(
|
||||
prompt="test",
|
||||
response=large_response,
|
||||
created_at=time.time(),
|
||||
)
|
||||
backend.set("key1", entry)
|
||||
|
||||
entries = export_cache()
|
||||
if entries:
|
||||
# Response should be truncated to 1000 chars
|
||||
assert len(entries[0]["response"]) <= 1000
|
||||
|
||||
|
||||
class TestStatsIntegration:
|
||||
"""Integration tests for stats with actual cache operations."""
|
||||
|
||||
def test_stats_tracking_with_cache_decorator(self):
|
||||
"""Test that stats are tracked during cache operations."""
|
||||
from semantic_llm_cache import cache
|
||||
|
||||
backend = MemoryBackend()
|
||||
_stats_manager.clear_stats("integration_test")
|
||||
|
||||
@cache(backend=backend, namespace="integration_test")
|
||||
def cached_func(prompt: str) -> str:
|
||||
return f"Response to: {prompt}"
|
||||
|
||||
# Generate activity
|
||||
cached_func("prompt1")
|
||||
cached_func("prompt1") # Hit
|
||||
cached_func("prompt2")
|
||||
|
||||
stats = get_stats("integration_test")
|
||||
assert stats["total_requests"] >= 2
|
||||
|
||||
def test_cache_invalidate_integration(self):
|
||||
"""Test invalidate removes entries from backend."""
|
||||
backend = _stats_manager.get_backend()
|
||||
backend.clear()
|
||||
|
||||
entry = CacheEntry(
|
||||
prompt="Python is great",
|
||||
response="Yes, it is!",
|
||||
created_at=time.time(),
|
||||
)
|
||||
backend.set("key1", entry)
|
||||
|
||||
# Verify entry exists
|
||||
assert backend.get("key1") is not None
|
||||
|
||||
# Invalidate
|
||||
invalidate("Python")
|
||||
|
||||
# Entry should be gone
|
||||
assert backend.get("key1") is None
|
||||
|
||||
def test_export_includes_metadata(self, tmp_path):
|
||||
"""Test export includes all metadata fields."""
|
||||
backend = _stats_manager.get_backend()
|
||||
backend.clear()
|
||||
|
||||
entry = CacheEntry(
|
||||
prompt="test prompt",
|
||||
response="test response",
|
||||
namespace="export_test",
|
||||
created_at=time.time(),
|
||||
ttl=7200,
|
||||
hit_count=10,
|
||||
input_tokens=500,
|
||||
output_tokens=250,
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
)
|
||||
backend.set("key1", entry)
|
||||
|
||||
entries = export_cache(namespace="export_test")
|
||||
if entries:
|
||||
e = entries[0]
|
||||
assert "created_at" in e
|
||||
assert "ttl" in e
|
||||
assert "hit_count" in e
|
||||
assert "input_tokens" in e
|
||||
assert "output_tokens" in e
|
||||
Loading…
Add table
Add a link
Reference in a new issue