diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a169d27 --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +.PHONY: help install install-dev test lint format clean build publish + +help: + @echo "Available commands:" + @echo " make install - Install package" + @echo " make install-dev - Install with dev dependencies" + @echo " make test - Run tests with coverage" + @echo " make lint - Run linting (ruff, black, mypy)" + @echo " make format - Format code with black and ruff" + @echo " make clean - Clean build artifacts" + @echo " make build - Build distribution packages" + @echo " make publish - Publish to PyPI" + +install: + pip install -e . + +install-dev: + pip install -e ".[dev,all]" + +test: + pytest + +lint: + ruff check . + black --check . + mypy . + +format: + ruff check --fix . + black . + +clean: + rm -rf build/ + rm -rf dist/ + rm -rf *.egg-info/ + rm -rf .pytest_cache/ + rm -rf htmlcov/ + rm -rf .coverage + find . -type d -name __pycache__ -exec rm -rf {} + + find . -type f -name "*.pyc" -delete + +build: clean + python -m build + +publish: build + twine upload dist/* diff --git a/README.md b/README.md new file mode 100644 index 0000000..90fa683 --- /dev/null +++ b/README.md @@ -0,0 +1,279 @@ +# semantic-llm-cache + +**Async semantic caching for LLM API calls — reduce costs with one decorator.** + +[![PyPI](https://img.shields.io/pypi/v/semantic-llm-cache)](https://pypi.org/project/semantic-llm-cache/) +[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) +[![Python](https://img.shields.io/pypi/pyversions/semantic-llm-cache)](https://pypi.org/project/semantic-llm-cache/) + +> **Fork of [karthyick/prompt-cache](https://github.com/karthyick/prompt-cache)** — fully converted to async for use with async frameworks (FastAPI, aiohttp, Starlette, etc.). + +## Overview + +LLM API calls are expensive and slow. In production applications, **20-40% of prompts are semantically identical** but get charged as separate API calls. 
`semantic-llm-cache` solves this with a simple decorator that:
+
+- ✅ **Caches semantically similar prompts** (not just exact matches)
+- ✅ **Reduces API costs by 20-40%**
+- ✅ **Returns cached responses in <10ms**
+- ✅ **Works with any LLM provider** (OpenAI, Anthropic, Ollama, local models)
+- ✅ **Fully async** — native `async/await` throughout, no event loop blocking
+- ✅ **Auto-detects** sync vs async decorated functions — one decorator for both
+
+## What changed from the original
+
+| Area | Original | This fork |
+| -------------------- | ------------------------- | ------------------------------------------------------------------- |
+| Backends | sync (`sqlite3`, `redis`) | async (`aiosqlite`, `redis.asyncio`) |
+| `@cache` decorator | sync only | auto-detects async/sync |
+| `EmbeddingCache` | sync `encode()` | adds `async aencode()` via `asyncio.to_thread` |
+| `CacheContext` | sync only | supports both `with` and `async with` |
+| `CachedLLM` | `chat()` | adds `achat()` |
+| Utility functions | sync | `clear_cache`, `invalidate`, `warm_cache`, `export_cache` all async |
+| `StorageBackend` ABC | sync abstract methods | all abstract methods are `async def` |
+| Min Python | 3.9 | 3.10 (uses `X \| Y` union syntax) |
+
+## Installation
+
+Not yet published to PyPI. Install directly from the repository:
+
+```bash
+# Clone
+git clone https://github.com/YOUR_ORG/prompt-cache.git
+cd prompt-cache
+
+# Core (exact match only, SQLite backend)
+pip install .
+
+# With semantic similarity (sentence-transformers)
+pip install ".[semantic]"
+
+# With Redis backend
+pip install ".[redis]"
+
+# Everything
+pip install ".[all]"
+```
+
+Or install directly via pip from git:
+
+```bash
+pip install "semantic-llm-cache @ git+https://github.com/nomyo-ai/async-semantic-llm-cache.git"
+pip install "semantic-llm-cache[semantic] @ git+https://github.com/nomyo-ai/async-semantic-llm-cache.git"
+```
+
+## Quick Start
+
+### Async function (FastAPI, aiohttp, etc.)
+ +```python +from semantic_llm_cache import cache + +@cache(similarity=0.95, ttl=3600) +async def ask_llm(prompt: str) -> str: + return await call_ollama(prompt) + +# First call — LLM hit +await ask_llm("What is Python?") + +# Second call — cache hit (<10ms, free) +await ask_llm("What's Python?") # 95% similar → cache hit +``` + +### Sync function (backwards compatible) + +```python +from semantic_llm_cache import cache + +@cache() +def ask_llm_sync(prompt: str) -> str: + return call_openai(prompt) # works, but don't use inside a running event loop +``` + +### Semantic Matching + +```python +from semantic_llm_cache import cache + +@cache(similarity=0.90) +async def ask_llm(prompt: str) -> str: + return await call_ollama(prompt) + +await ask_llm("What is Python?") # LLM call +await ask_llm("What's Python?") # cache hit (95% similar) +await ask_llm("Explain Python") # cache hit (91% similar) +await ask_llm("What is Rust?") # LLM call (different topic) +``` + +### SQLite backend (default, persistent) + +```python +from semantic_llm_cache import cache +from semantic_llm_cache.backends import SQLiteBackend + +backend = SQLiteBackend(db_path="my_cache.db") + +@cache(backend=backend, similarity=0.95) +async def ask_llm(prompt: str) -> str: + return await call_ollama(prompt) +``` + +### Redis backend (distributed) + +```python +from semantic_llm_cache import cache +from semantic_llm_cache.backends import RedisBackend + +backend = RedisBackend(url="redis://localhost:6379/0") +await backend.ping() # verify connection (replaces __init__ connection test) + +@cache(backend=backend, similarity=0.95) +async def ask_llm(prompt: str) -> str: + return await call_ollama(prompt) +``` + +### Cache Statistics + +```python +from semantic_llm_cache import get_stats + +stats = get_stats() +# { +# "hits": 1547, +# "misses": 892, +# "hit_rate": 0.634, +# "estimated_savings_usd": 3.09, +# "total_saved_ms": 773500 +# } +``` + +### Cache Management + +```python +from semantic_llm_cache.stats 
import clear_cache, invalidate + +# Clear all cached entries +await clear_cache() + +# Invalidate entries matching a pattern +await invalidate(pattern="Python") +``` + +### Async context manager + +```python +from semantic_llm_cache import CacheContext + +async with CacheContext(similarity=0.9) as ctx: + result1 = await any_cached_llm_call("prompt 1") + result2 = await any_cached_llm_call("prompt 2") + +print(ctx.stats) # {"hits": 1, "misses": 1} +``` + +### CachedLLM wrapper + +```python +from semantic_llm_cache import CachedLLM + +llm = CachedLLM(similarity=0.9, ttl=3600) +response = await llm.achat("What is Python?", llm_func=my_async_llm) +``` + +## API Reference + +### `@cache()` Decorator + +```python +@cache( + similarity: float = 1.0, # 1.0 = exact match, 0.9 = semantic + ttl: int = 3600, # seconds, None = forever + backend: Backend = None, # None = in-memory + namespace: str = "default", # isolate different use cases + enabled: bool = True, # toggle for debugging + key_func: Callable = None, # custom cache key +) +async def my_llm_function(prompt: str) -> str: + ... 
+``` + +### Parameters + +| Parameter | Type | Default | Description | +| ------------ | ------------- | ----------- | --------------------------------------------------------- | +| `similarity` | `float` | `1.0` | Cosine similarity threshold (1.0 = exact, 0.9 = semantic) | +| `ttl` | `int \| None` | `3600` | Time-to-live in seconds (None = never expires) | +| `backend` | `Backend` | `None` | Storage backend (None = in-memory) | +| `namespace` | `str` | `"default"` | Isolate different use cases | +| `enabled` | `bool` | `True` | Enable/disable caching | +| `key_func` | `Callable` | `None` | Custom cache key function | + +### Utility Functions + +```python +from semantic_llm_cache import get_stats # sync — safe anywhere +from semantic_llm_cache.stats import ( + clear_cache, # async + invalidate, # async + warm_cache, # async + export_cache, # async +) +``` + +## Backends + +| Backend | Description | I/O | +| --------------- | ------------------------------------ | ------------------------- | +| `MemoryBackend` | In-memory LRU (default) | none — runs in event loop | +| `SQLiteBackend` | Persistent, file-based (`aiosqlite`) | async non-blocking | +| `RedisBackend` | Distributed (`redis.asyncio`) | async non-blocking | + +## Embedding Providers + +| Provider | Quality | Notes | +| ----------------------------- | ---------------------------- | --------------------------- | +| `DummyEmbeddingProvider` | hash-only, no semantic match | zero deps, default | +| `SentenceTransformerProvider` | high (local model) | requires `[semantic]` extra | +| `OpenAIEmbeddingProvider` | high (API) | requires `[openai]` extra | + +Embedding inference is offloaded via `asyncio.to_thread` — model loading is blocking and should be done at application startup, not on first request. 
+ +```python +from semantic_llm_cache.similarity import create_embedding_provider, EmbeddingCache + +# Pre-load at startup (blocking — do this in lifespan, not a request handler) +provider = create_embedding_provider("sentence-transformer") +embedding_cache = EmbeddingCache(provider=provider) + +# Use in request handlers (non-blocking) +embedding = await embedding_cache.aencode("my prompt") +``` + +## Performance + +| Metric | Value | +| -------------------------- | ---------------------------------------- | +| Cache hit latency | <10ms | +| Embedding overhead on miss | ~50ms (sentence-transformers, offloaded) | +| Typical hit rate | 25-40% | +| Cost reduction | 20-40% | + +## Requirements + +- Python >= 3.10 +- numpy >= 1.24.0 +- aiosqlite >= 0.19.0 + +### Optional + +- `sentence-transformers >= 2.2.0` — semantic matching +- `redis >= 4.2.0` — Redis backend (includes `redis.asyncio`) +- `openai >= 1.0.0` — OpenAI embeddings + +## License + +MIT — see [LICENSE](LICENSE). + +## Credits + +Original library by **Karthick Raja M** ([@karthyick](https://github.com/karthyick)). +Async conversion by this fork. 
diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c0c3a32 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,112 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "semantic-llm-cache" +version = "0.2.0" +description = "Async semantic caching for LLM API calls - reduce costs with one decorator" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "MIT"} +authors = [ + {name = "Karthick Raja M", email = "karthickrajam18@gmail.com"} +] +keywords = [ + "llm", + "cache", + "semantic", + "async", + "openai", + "anthropic", + "ollama", + "prompt", + "optimization", + "cost-reduction", +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Framework :: AsyncIO", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + +dependencies = [ + "numpy>=1.24.0", + "aiosqlite>=0.19.0", +] + +[project.optional-dependencies] +semantic = [ + "sentence-transformers>=2.2.0", +] +redis = [ + "redis>=4.2.0", +] +openai = [ + "openai>=1.0.0", +] +all = [ + "sentence-transformers>=2.2.0", + "redis>=4.2.0", + "openai>=1.0.0", +] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=4.0.0", + "black>=23.0.0", + "ruff>=0.1.0", + "mypy>=1.0.0", + "build>=1.0.0", + "twine>=4.0.0", + "bump2version>=1.0.1", +] + +[project.urls] +Homepage = "https://github.com/karthyick/prompt-cache" +Documentation = "https://github.com/karthyick/prompt-cache#readme" +Repository = "https://github.com/karthyick/prompt-cache.git" +"Bug Tracker" = 
"https://github.com/karthyick/prompt-cache/issues" + +[tool.setuptools.packages.find] +include = ["semantic_llm_cache*"] +exclude = ["tests*", "examples*", "docs*"] + +[tool.setuptools.package-data] +semantic_llm_cache = ["py.typed"] + +[tool.black] +line-length = 100 +target-version = ['py310', 'py311', 'py312', 'py313'] + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP", "B", "C4"] +ignore = ["E501"] + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disable_error_code = ["annotation-unchecked"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +asyncio_mode = "auto" +addopts = "-v --cov=semantic_llm_cache --cov-report=html --cov-report=term-missing --cov-fail-under=90" diff --git a/semantic_llm_cache/__init__.py b/semantic_llm_cache/__init__.py new file mode 100644 index 0000000..d9780bb --- /dev/null +++ b/semantic_llm_cache/__init__.py @@ -0,0 +1,54 @@ +""" +llm-semantic-cache: Semantic caching for LLM API calls. + +Cut LLM costs 30% with one decorator. 
+""" + +__version__ = "0.1.0" +__author__ = "Karthick Raja M" +__license__ = "MIT" + +# Core exports +from semantic_llm_cache.config import CacheConfig +from semantic_llm_cache.core import ( + CacheContext, + CachedLLM, + cache, + get_default_backend, + set_default_backend, +) +from semantic_llm_cache.exceptions import ( + CacheBackendError, + CacheNotFoundError, + CacheSerializationError, + PromptCacheError, +) +from semantic_llm_cache.stats import CacheStats, clear_cache, get_stats, invalidate +from semantic_llm_cache.storage import StorageBackend + +__all__ = [ + # Version info + "__version__", + "__author__", + "__license__", + # Core API + "cache", + "CacheContext", + "CachedLLM", + "get_default_backend", + "set_default_backend", + # Storage + "StorageBackend", + # Statistics + "CacheStats", + "get_stats", + "clear_cache", + "invalidate", + # Configuration + "CacheConfig", + # Exceptions + "PromptCacheError", + "CacheBackendError", + "CacheSerializationError", + "CacheNotFoundError", +] diff --git a/semantic_llm_cache/backends/__init__.py b/semantic_llm_cache/backends/__init__.py new file mode 100644 index 0000000..aa94669 --- /dev/null +++ b/semantic_llm_cache/backends/__init__.py @@ -0,0 +1,21 @@ +"""Storage backends for llm-semantic-cache.""" + +from semantic_llm_cache.backends.base import BaseBackend +from semantic_llm_cache.backends.memory import MemoryBackend + +try: + from semantic_llm_cache.backends.sqlite import SQLiteBackend +except ImportError: + SQLiteBackend = None # type: ignore + +try: + from semantic_llm_cache.backends.redis import RedisBackend +except ImportError: + RedisBackend = None # type: ignore + +__all__ = [ + "BaseBackend", + "MemoryBackend", + "SQLiteBackend", + "RedisBackend", +] diff --git a/semantic_llm_cache/backends/base.py b/semantic_llm_cache/backends/base.py new file mode 100644 index 0000000..68b8b0a --- /dev/null +++ b/semantic_llm_cache/backends/base.py @@ -0,0 +1,104 @@ +"""Base backend implementation with common 
functionality.""" + +import time +from typing import Any, Optional + +import numpy as np + +from semantic_llm_cache.config import CacheEntry +from semantic_llm_cache.storage import StorageBackend + + +def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float: + """Calculate cosine similarity between two vectors. + + Args: + a: First vector + b: Second vector + + Returns: + Similarity score between 0 and 1 + """ + a_arr = np.asarray(a, dtype=np.float32) + b_arr = np.asarray(b, dtype=np.float32) + + dot_product = np.dot(a_arr, b_arr) + norm_a = np.linalg.norm(a_arr) + norm_b = np.linalg.norm(b_arr) + + if norm_a == 0 or norm_b == 0: + return 0.0 + + return float(dot_product / (norm_a * norm_b)) + + +class BaseBackend(StorageBackend): + """Base backend with common sync helpers; async public interface via StorageBackend.""" + + def __init__(self) -> None: + """Initialize base backend.""" + self._hits: int = 0 + self._misses: int = 0 + + def _increment_hits(self) -> None: + """Increment hit counter.""" + self._hits += 1 + + def _increment_misses(self) -> None: + """Increment miss counter.""" + self._misses += 1 + + def _check_expired(self, entry: CacheEntry) -> bool: + """Check if entry is expired. + + Args: + entry: CacheEntry to check + + Returns: + True if expired, False otherwise + """ + return entry.is_expired(time.time()) + + def _find_best_match( + self, + candidates: list[tuple[str, CacheEntry]], + query_embedding: list[float], + threshold: float, + ) -> Optional[tuple[str, CacheEntry, float]]: + """Find best matching entry from candidates. + + Sync helper — CPU-only numpy ops, safe to call from async context. 
+
+        Args:
+            candidates: List of (key, entry) tuples
+            query_embedding: Query embedding vector
+            threshold: Minimum similarity threshold (inclusive)
+
+        Returns:
+            (key, entry, similarity) tuple if found at or above threshold, None otherwise
+        """
+        best_match: Optional[tuple[str, CacheEntry, float]] = None
+        best_similarity = -1.0  # best score so far; threshold is checked inclusively below
+
+        for key, entry in candidates:
+            if entry.embedding is None:
+                continue
+
+            similarity = cosine_similarity(query_embedding, entry.embedding)
+            if similarity >= threshold and similarity > best_similarity:
+                best_similarity = similarity
+                best_match = (key, entry, similarity)
+
+        return best_match
+
+    async def get_stats(self) -> dict[str, Any]:
+        """Get backend statistics.
+
+        Returns:
+            Dictionary with hits and misses
+        """
+        return {
+            "hits": self._hits,
+            "misses": self._misses,
+            "hit_rate": self._hits / max(self._hits + self._misses, 1),
+        }
diff --git a/semantic_llm_cache/backends/memory.py b/semantic_llm_cache/backends/memory.py
new file mode 100644
index 0000000..29a368a
--- /dev/null
+++ b/semantic_llm_cache/backends/memory.py
@@ -0,0 +1,179 @@
+"""In-memory storage backend."""
+
+import sys
+from typing import Any, Optional
+
+from semantic_llm_cache.backends.base import BaseBackend
+from semantic_llm_cache.config import CacheEntry
+from semantic_llm_cache.exceptions import CacheBackendError
+
+
+class MemoryBackend(BaseBackend):
+    """In-memory cache storage with LRU eviction.
+
+    All operations are in-memory dict access — no I/O — so async methods
+    run directly in the event loop without thread offloading.
+    """
+
+    def __init__(self, max_size: Optional[int] = None) -> None:
+        """Initialize memory backend.
+ + Args: + max_size: Maximum number of entries to store (LRU eviction when reached) + """ + super().__init__() + self._cache: dict[str, CacheEntry] = {} + self._access_order: dict[str, float] = {} + self._max_size = max_size + self._access_counter: float = 0.0 + + def _evict_if_needed(self) -> None: + """Evict oldest entry if at capacity.""" + if self._max_size is None or len(self._cache) < self._max_size: + return + + if self._access_order: + lru_key = min(self._access_order, key=lambda k: self._access_order.get(k, 0)) + del self._cache[lru_key] + del self._access_order[lru_key] + + def _update_access_time(self, key: str) -> None: + """Update access time for LRU tracking.""" + self._access_counter += 1 + self._access_order[key] = self._access_counter + + async def get(self, key: str) -> Optional[CacheEntry]: + """Retrieve cache entry by key. + + Args: + key: Cache key to retrieve + + Returns: + CacheEntry if found and not expired, None otherwise + """ + try: + entry = self._cache.get(key) + if entry is None: + self._increment_misses() + return None + + if self._check_expired(entry): + await self.delete(key) + self._increment_misses() + return None + + self._increment_hits() + self._update_access_time(key) + entry.hit_count += 1 + return entry + except Exception as e: + raise CacheBackendError(f"Failed to get entry: {e}") from e + + async def set(self, key: str, entry: CacheEntry) -> None: + """Store cache entry. + + Args: + key: Cache key to store under + entry: CacheEntry to store + """ + try: + self._evict_if_needed() + self._cache[key] = entry + self._update_access_time(key) + except Exception as e: + raise CacheBackendError(f"Failed to set entry: {e}") from e + + async def delete(self, key: str) -> bool: + """Delete cache entry. 
+
+        Args:
+            key: Cache key to delete
+
+        Returns:
+            True if entry was deleted, False if not found
+        """
+        try:
+            if key in self._cache:
+                del self._cache[key]
+                self._access_order.pop(key, None)
+                return True
+            return False
+        except Exception as e:
+            raise CacheBackendError(f"Failed to delete entry: {e}") from e
+
+    async def clear(self) -> None:
+        """Clear all cache entries."""
+        try:
+            self._cache.clear()
+            self._access_order.clear()
+        except Exception as e:
+            raise CacheBackendError(f"Failed to clear cache: {e}") from e
+
+    async def iterate(
+        self, namespace: Optional[str] = None
+    ) -> list[tuple[str, CacheEntry]]:
+        """Iterate over cache entries, optionally filtered by namespace.
+
+        Args:
+            namespace: Optional namespace filter
+
+        Returns:
+            List of non-expired (key, entry) tuples
+        """
+        try:
+            if namespace is None:
+                return [(k, v) for k, v in self._cache.items() if not self._check_expired(v)]  # skip expired, matching the namespaced branch
+
+            return [
+                (k, v)
+                for k, v in self._cache.items()
+                if v.namespace == namespace and not self._check_expired(v)
+            ]
+        except Exception as e:
+            raise CacheBackendError(f"Failed to iterate entries: {e}") from e
+
+    async def find_similar(
+        self,
+        embedding: list[float],
+        threshold: float,
+        namespace: Optional[str] = None,
+    ) -> Optional[tuple[str, CacheEntry, float]]:
+        """Find semantically similar cached entry.
+
+        Args:
+            embedding: Query embedding vector
+            threshold: Minimum similarity score (0-1)
+            namespace: Optional namespace filter
+
+        Returns:
+            (key, entry, similarity) tuple if found above threshold, None otherwise
+        """
+        try:
+            candidates = [
+                (k, v)
+                for k, v in self._cache.items()
+                if v.embedding is not None
+                and not self._check_expired(v)
+                and (namespace is None or v.namespace == namespace)
+            ]
+            return self._find_best_match(candidates, embedding, threshold)
+        except Exception as e:
+            raise CacheBackendError(f"Failed to find similar entry: {e}") from e
+
+    async def get_stats(self) -> dict[str, Any]:
+        """Get backend statistics.
+ + Returns: + Dictionary with size, memory usage, hits, misses + """ + base_stats = await super().get_stats() + memory_usage = sys.getsizeof(self._cache) + sum( + sys.getsizeof(k) + sys.getsizeof(v) for k, v in self._cache.items() + ) + + return { + **base_stats, + "size": len(self._cache), + "memory_bytes": memory_usage, + "max_size": self._max_size, + } diff --git a/semantic_llm_cache/backends/redis.py b/semantic_llm_cache/backends/redis.py new file mode 100644 index 0000000..6a02f14 --- /dev/null +++ b/semantic_llm_cache/backends/redis.py @@ -0,0 +1,239 @@ +"""Redis distributed storage backend (async via redis.asyncio).""" + +import json +from typing import Any, Optional + +try: + from redis import asyncio as aioredis +except ImportError as err: + raise ImportError( + "Redis backend requires 'redis' package. " + "Install with: pip install semantic-llm-cache[redis]" + ) from err + +from semantic_llm_cache.backends.base import BaseBackend +from semantic_llm_cache.config import CacheEntry +from semantic_llm_cache.exceptions import CacheBackendError + + +class RedisBackend(BaseBackend): + """Redis-based distributed cache storage (async). + + Uses redis.asyncio (bundled with redis>=4.2) for non-blocking I/O. + The connection is created in __init__; no explicit connect() call needed + as redis.asyncio uses a connection pool that connects lazily. + """ + + DEFAULT_PREFIX = "semantic_llm_cache:" + + def __init__( + self, + url: str = "redis://localhost:6379/0", + prefix: str = DEFAULT_PREFIX, + **kwargs: Any, + ) -> None: + """Initialize Redis backend. + + Args: + url: Redis connection URL + prefix: Key prefix for cache entries + **kwargs: Additional arguments passed to redis.asyncio.from_url + """ + super().__init__() + self._prefix = prefix.rstrip(":") + ":" + self._redis = aioredis.from_url(url, **kwargs) + + async def ping(self) -> None: + """Test Redis connection. Call this after construction to verify connectivity. 
+ + Raises: + CacheBackendError: If Redis is not reachable + """ + try: + await self._redis.ping() + except Exception as e: + raise CacheBackendError(f"Failed to connect to Redis: {e}") from e + + def _make_key(self, key: str) -> str: + """Create full Redis key with prefix.""" + return f"{self._prefix}{key}" + + def _entry_to_dict(self, entry: CacheEntry) -> dict[str, Any]: + """Convert CacheEntry to dictionary for storage.""" + return { + "prompt": entry.prompt, + "response": entry.response, + "embedding": entry.embedding, + "created_at": entry.created_at, + "ttl": entry.ttl, + "namespace": entry.namespace, + "hit_count": entry.hit_count, + "input_tokens": entry.input_tokens, + "output_tokens": entry.output_tokens, + } + + def _dict_to_entry(self, data: dict[str, Any]) -> CacheEntry: + """Convert dictionary from storage to CacheEntry.""" + return CacheEntry( + prompt=data["prompt"], + response=data["response"], + embedding=data.get("embedding"), + created_at=data["created_at"], + ttl=data.get("ttl"), + namespace=data.get("namespace", "default"), + hit_count=data.get("hit_count", 0), + input_tokens=data.get("input_tokens", 0), + output_tokens=data.get("output_tokens", 0), + ) + + async def get(self, key: str) -> Optional[CacheEntry]: + """Retrieve cache entry by key. 
+
+        Args:
+            key: Cache key to retrieve
+
+        Returns:
+            CacheEntry if found and not expired, None otherwise
+        """
+        try:
+            redis_key = self._make_key(key)
+            data = await self._redis.get(redis_key)
+
+            if data is None:
+                self._increment_misses()
+                return None
+
+            entry_dict = json.loads(data)
+            entry = self._dict_to_entry(entry_dict)
+
+            if self._check_expired(entry):
+                await self.delete(key)
+                self._increment_misses()
+                return None
+
+            self._increment_hits()
+            entry.hit_count += 1
+
+            entry_dict["hit_count"] = entry.hit_count
+            await self._redis.set(redis_key, json.dumps(entry_dict), keepttl=True)  # keepttl (Redis 6.0+): don't strip the TTL on every hit
+
+            return entry
+        except Exception as e:
+            raise CacheBackendError(f"Failed to get entry: {e}") from e
+
+    async def set(self, key: str, entry: CacheEntry) -> None:
+        """Store cache entry.
+
+        Args:
+            key: Cache key to store under
+            entry: CacheEntry to store
+        """
+        try:
+            redis_key = self._make_key(key)
+            data = json.dumps(self._entry_to_dict(entry))
+            redis_ttl = entry.ttl if entry.ttl is not None else 0
+            await self._redis.set(redis_key, data, ex=redis_ttl if redis_ttl > 0 else None)
+        except Exception as e:
+            raise CacheBackendError(f"Failed to set entry: {e}") from e
+
+    async def delete(self, key: str) -> bool:
+        """Delete cache entry.
+
+        Args:
+            key: Cache key to delete
+
+        Returns:
+            True if entry was deleted, False if not found
+        """
+        try:
+            result = await self._redis.delete(self._make_key(key))
+            return result > 0
+        except Exception as e:
+            raise CacheBackendError(f"Failed to delete entry: {e}") from e
+
+    async def clear(self) -> None:
+        """Clear all cache entries with this prefix."""
+        try:
+            keys = await self._redis.keys(f"{self._prefix}*")
+            if keys:
+                await self._redis.delete(*keys)
+        except Exception as e:
+            raise CacheBackendError(f"Failed to clear cache: {e}") from e
+
+    async def iterate(
+        self, namespace: Optional[str] = None
+    ) -> list[tuple[str, CacheEntry]]:
+        """Iterate over cache entries, optionally filtered by namespace.
+ + Args: + namespace: Optional namespace filter + + Returns: + List of (key, entry) tuples + """ + try: + keys = await self._redis.keys(f"{self._prefix}*") + results = [] + + for full_key in keys: + short_key = full_key.decode().replace(self._prefix, "", 1) + data = await self._redis.get(full_key) + + if data: + entry_dict = json.loads(data) + entry = self._dict_to_entry(entry_dict) + + if namespace is None or entry.namespace == namespace: + if not self._check_expired(entry): + results.append((short_key, entry)) + + return results + except Exception as e: + raise CacheBackendError(f"Failed to iterate entries: {e}") from e + + async def find_similar( + self, + embedding: list[float], + threshold: float, + namespace: Optional[str] = None, + ) -> Optional[tuple[str, CacheEntry, float]]: + """Find semantically similar cached entry. + + Note: Loads all entries for cosine scan. For large datasets consider + Redis Stack with vector search (RediSearch). + + Args: + embedding: Query embedding vector + threshold: Minimum similarity score (0-1) + namespace: Optional namespace filter + + Returns: + (key, entry, similarity) tuple if found above threshold, None otherwise + """ + try: + entries = await self.iterate(namespace) + candidates = [(k, v) for k, v in entries if v.embedding is not None] + return self._find_best_match(candidates, embedding, threshold) + except Exception as e: + raise CacheBackendError(f"Failed to find similar entry: {e}") from e + + async def get_stats(self) -> dict[str, Any]: + """Get backend statistics.""" + base_stats = await super().get_stats() + + try: + keys = await self._redis.keys(f"{self._prefix}*") + return { + **base_stats, + "size": len(keys) if keys else 0, + "prefix": self._prefix, + } + except Exception as e: + return {**base_stats, "size": 0, "prefix": self._prefix, "error": str(e)} + + async def close(self) -> None: + """Close Redis connection.""" + try: + await self._redis.aclose() + except Exception: + pass diff --git 
a/semantic_llm_cache/backends/sqlite.py b/semantic_llm_cache/backends/sqlite.py new file mode 100644 index 0000000..5bf200a --- /dev/null +++ b/semantic_llm_cache/backends/sqlite.py @@ -0,0 +1,279 @@ +"""SQLite persistent storage backend (async via aiosqlite).""" + +import json +from pathlib import Path +from typing import Any, Optional + +try: + import aiosqlite +except ImportError as err: + raise ImportError( + "SQLite backend requires 'aiosqlite' package. " + "Install with: pip install semantic-llm-cache[sqlite]" + ) from err + +from semantic_llm_cache.backends.base import BaseBackend +from semantic_llm_cache.config import CacheEntry +from semantic_llm_cache.exceptions import CacheBackendError + + +class SQLiteBackend(BaseBackend): + """SQLite-based persistent cache storage (async). + + Uses aiosqlite for non-blocking I/O. A single persistent connection + is opened lazily on first use and reused for all subsequent operations. + """ + + def __init__(self, db_path: str | Path = "semantic_cache.db") -> None: + """Initialize SQLite backend. 
+ + Args: + db_path: Path to SQLite database file, or ":memory:" for in-memory DB + """ + super().__init__() + self._db_path = str(db_path) if isinstance(db_path, Path) else db_path + self._conn: Optional[aiosqlite.Connection] = None + + async def _get_conn(self) -> aiosqlite.Connection: + """Get or create the persistent async connection.""" + if self._conn is None: + self._conn = await aiosqlite.connect(self._db_path) + self._conn.row_factory = aiosqlite.Row + await self._initialize_schema() + return self._conn + + async def _initialize_schema(self) -> None: + """Initialize database schema.""" + conn = await self._get_conn() + await conn.execute( + """ + CREATE TABLE IF NOT EXISTS cache_entries ( + key TEXT PRIMARY KEY, + prompt TEXT NOT NULL, + response TEXT NOT NULL, + embedding TEXT, + created_at REAL NOT NULL, + ttl INTEGER, + namespace TEXT NOT NULL DEFAULT 'default', + hit_count INTEGER DEFAULT 0, + input_tokens INTEGER DEFAULT 0, + output_tokens INTEGER DEFAULT 0 + ) + """ + ) + await conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_namespace + ON cache_entries(namespace) + """ + ) + await conn.commit() + + def _row_to_entry(self, row: aiosqlite.Row) -> CacheEntry: + """Convert database row to CacheEntry.""" + embedding = None + if row["embedding"]: + embedding = json.loads(row["embedding"]) + + return CacheEntry( + prompt=row["prompt"], + response=json.loads(row["response"]), + embedding=embedding, + created_at=row["created_at"], + ttl=row["ttl"], + namespace=row["namespace"], + hit_count=row["hit_count"], + input_tokens=row["input_tokens"], + output_tokens=row["output_tokens"], + ) + + async def get(self, key: str) -> Optional[CacheEntry]: + """Retrieve cache entry by key. 
+ + Args: + key: Cache key to retrieve + + Returns: + CacheEntry if found and not expired, None otherwise + """ + try: + conn = await self._get_conn() + async with conn.execute( + "SELECT * FROM cache_entries WHERE key = ?", (key,) + ) as cursor: + row = await cursor.fetchone() + + if row is None: + self._increment_misses() + return None + + entry = self._row_to_entry(row) + + if self._check_expired(entry): + await self.delete(key) + self._increment_misses() + return None + + self._increment_hits() + entry.hit_count += 1 + + await conn.execute( + "UPDATE cache_entries SET hit_count = hit_count + 1 WHERE key = ?", + (key,), + ) + await conn.commit() + + return entry + except Exception as e: + raise CacheBackendError(f"Failed to get entry: {e}") from e + + async def set(self, key: str, entry: CacheEntry) -> None: + """Store cache entry. + + Args: + key: Cache key to store under + entry: CacheEntry to store + """ + try: + conn = await self._get_conn() + embedding_json = json.dumps(entry.embedding) if entry.embedding else None + response_json = json.dumps(entry.response) + + await conn.execute( + """ + INSERT OR REPLACE INTO cache_entries + (key, prompt, response, embedding, created_at, ttl, namespace, + hit_count, input_tokens, output_tokens) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + key, + entry.prompt, + response_json, + embedding_json, + entry.created_at, + entry.ttl, + entry.namespace, + entry.hit_count, + entry.input_tokens, + entry.output_tokens, + ), + ) + await conn.commit() + except Exception as e: + raise CacheBackendError(f"Failed to set entry: {e}") from e + + async def delete(self, key: str) -> bool: + """Delete cache entry. 
+ + Args: + key: Cache key to delete + + Returns: + True if entry was deleted, False if not found + """ + try: + conn = await self._get_conn() + async with conn.execute( + "DELETE FROM cache_entries WHERE key = ?", (key,) + ) as cursor: + rowcount = cursor.rowcount + await conn.commit() + return rowcount > 0 + except Exception as e: + raise CacheBackendError(f"Failed to delete entry: {e}") from e + + async def clear(self) -> None: + """Clear all cache entries.""" + try: + conn = await self._get_conn() + await conn.execute("DELETE FROM cache_entries") + await conn.commit() + except Exception as e: + raise CacheBackendError(f"Failed to clear cache: {e}") from e + + async def iterate( + self, namespace: Optional[str] = None + ) -> list[tuple[str, CacheEntry]]: + """Iterate over cache entries, optionally filtered by namespace. + + Args: + namespace: Optional namespace filter + + Returns: + List of (key, entry) tuples + """ + try: + conn = await self._get_conn() + + if namespace is None: + query = "SELECT key, * FROM cache_entries" + params: tuple[()] = () + else: + query = "SELECT key, * FROM cache_entries WHERE namespace = ?" + params = (namespace,) + + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + + results = [] + for row in rows: + key = row["key"] + entry = self._row_to_entry(row) + if not self._check_expired(entry): + results.append((key, entry)) + + return results + except Exception as e: + raise CacheBackendError(f"Failed to iterate entries: {e}") from e + + async def find_similar( + self, + embedding: list[float], + threshold: float, + namespace: Optional[str] = None, + ) -> Optional[tuple[str, CacheEntry, float]]: + """Find semantically similar cached entry. 
+ + Args: + embedding: Query embedding vector + threshold: Minimum similarity score (0-1) + namespace: Optional namespace filter + + Returns: + (key, entry, similarity) tuple if found above threshold, None otherwise + """ + try: + entries = await self.iterate(namespace) + candidates = [(k, v) for k, v in entries if v.embedding is not None] + return self._find_best_match(candidates, embedding, threshold) + except Exception as e: + raise CacheBackendError(f"Failed to find similar entry: {e}") from e + + async def get_stats(self) -> dict[str, Any]: + """Get backend statistics. + + Returns: + Dictionary with size, database path, hits, misses + """ + base_stats = await super().get_stats() + + try: + conn = await self._get_conn() + async with conn.execute("SELECT COUNT(*) FROM cache_entries") as cursor: + row = await cursor.fetchone() + size = row[0] if row else 0 + + return { + **base_stats, + "size": size, + "db_path": self._db_path, + } + except Exception as e: + return {**base_stats, "size": 0, "db_path": self._db_path, "error": str(e)} + + async def close(self) -> None: + """Close database connection.""" + if self._conn is not None: + await self._conn.close() + self._conn = None diff --git a/semantic_llm_cache/config.py b/semantic_llm_cache/config.py new file mode 100644 index 0000000..f94f76e --- /dev/null +++ b/semantic_llm_cache/config.py @@ -0,0 +1,61 @@ +"""Configuration management for prompt-cache.""" + +from dataclasses import dataclass +from typing import Any, Callable, Optional + + +@dataclass +class CacheConfig: + """Configuration for cache behavior.""" + + similarity_threshold: float = 1.0 # 1.0 = exact match, lower = semantic + ttl: Optional[int] = 3600 # Time to live in seconds, None = forever + namespace: str = "default" # Isolate different use cases + enabled: bool = True # Enable/disable caching + key_func: Optional[Callable[[Any], str]] = None # Custom cache key function + + # Cost estimation for statistics (USD per 1K tokens) + input_cost_per_1k: 
float = 0.001 # Default ~$1/1M for cheaper models + output_cost_per_1k: float = 0.002 # Default ~$2/1M for cheaper models + + # Performance settings + max_cache_size: Optional[int] = None # LRU eviction when set + embedding_model: str = "all-MiniLM-L6-v2" # Default sentence-transformer model + + def __post_init__(self) -> None: + """Validate configuration.""" + if not 0.0 <= self.similarity_threshold <= 1.0: + raise ValueError("similarity_threshold must be between 0.0 and 1.0") + if self.ttl is not None and self.ttl <= 0: + raise ValueError("ttl must be positive or None") + if self.max_cache_size is not None and self.max_cache_size <= 0: + raise ValueError("max_cache_size must be positive or None") + + +@dataclass +class CacheEntry: + """A cached response with metadata.""" + + prompt: str + response: Any + embedding: Optional[list[float]] = None # Normalized embedding vector + created_at: float = 0.0 # Unix timestamp + ttl: Optional[int] = None # Time to live in seconds + namespace: str = "default" + hit_count: int = 0 + + # Approximate token counts for cost estimation + input_tokens: int = 0 + output_tokens: int = 0 + + def is_expired(self, current_time: float) -> bool: + """Check if entry has expired based on TTL.""" + if self.ttl is None: + return False + return (current_time - self.created_at) > self.ttl + + def estimate_cost(self, input_cost: float, output_cost: float) -> float: + """Estimate cost savings in USD.""" + input_savings = (self.input_tokens / 1000) * input_cost + output_savings = (self.output_tokens / 1000) * output_cost + return input_savings + output_savings diff --git a/semantic_llm_cache/core.py b/semantic_llm_cache/core.py new file mode 100644 index 0000000..087ef66 --- /dev/null +++ b/semantic_llm_cache/core.py @@ -0,0 +1,369 @@ +"""Core cache decorator and API for llm-semantic-cache.""" + +import functools +import inspect +import time +from typing import Any, Callable, Optional, ParamSpec, TypeVar + +from semantic_llm_cache.backends import 
MemoryBackend +from semantic_llm_cache.backends.base import BaseBackend +from semantic_llm_cache.config import CacheConfig, CacheEntry +from semantic_llm_cache.exceptions import PromptCacheError +from semantic_llm_cache.similarity import EmbeddingCache +from semantic_llm_cache.stats import _stats_manager +from semantic_llm_cache.utils import hash_prompt, normalize_prompt + +P = ParamSpec("P") +R = TypeVar("R") + + +def _extract_prompt(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: + """Extract prompt string from function arguments.""" + if args and isinstance(args[0], str): + return args[0] + if "prompt" in kwargs: + return str(kwargs["prompt"]) + return str(args) + str(sorted(kwargs.items())) + + +class CacheContext: + """Context manager for cache configuration. + + Supports both sync (with) and async (async with) usage. + + Examples: + >>> async with CacheContext(similarity=0.9) as ctx: + ... result = await llm_call("prompt") + ... print(ctx.stats) + """ + + def __init__( + self, + similarity: Optional[float] = None, + ttl: Optional[int] = None, + namespace: Optional[str] = None, + enabled: Optional[bool] = None, + ) -> None: + self._config = CacheConfig( + similarity_threshold=similarity if similarity is not None else 1.0, + ttl=ttl, + namespace=namespace if namespace is not None else "default", + enabled=enabled if enabled is not None else True, + ) + self._stats: dict[str, Any] = {"hits": 0, "misses": 0} + + def __enter__(self) -> "CacheContext": + return self + + def __exit__(self, *args: Any) -> None: + pass + + async def __aenter__(self) -> "CacheContext": + return self + + async def __aexit__(self, *args: Any) -> None: + pass + + @property + def stats(self) -> dict[str, Any]: + return self._stats.copy() + + @property + def config(self) -> CacheConfig: + return self._config + + +class CachedLLM: + """Wrapper class for LLM calls with automatic caching. 
+ + Examples: + >>> llm = CachedLLM(similarity=0.9) + >>> response = await llm.achat("What is Python?", llm_func=my_async_llm) + """ + + def __init__( + self, + provider: str = "openai", + model: str = "gpt-4", + similarity: float = 1.0, + ttl: Optional[int] = 3600, + backend: Optional[BaseBackend] = None, + namespace: str = "default", + enabled: bool = True, + ) -> None: + self._provider = provider + self._model = model + self._backend = backend or MemoryBackend() + self._embedding_cache = EmbeddingCache() + self._config = CacheConfig( + similarity_threshold=similarity, + ttl=ttl, + namespace=namespace, + enabled=enabled, + ) + + async def achat( + self, + prompt: str, + llm_func: Optional[Callable[[str], Any]] = None, + **kwargs: Any, + ) -> Any: + """Get response with caching (async). + + Args: + prompt: Input prompt + llm_func: Async or sync LLM function to call on cache miss + **kwargs: Additional arguments for llm_func + + Returns: + LLM response (cached or fresh) + """ + if llm_func is None: + raise ValueError("llm_func is required for CachedLLM.achat()") + + @cache( + similarity=self._config.similarity_threshold, + ttl=self._config.ttl, + backend=self._backend, + namespace=self._config.namespace, + enabled=self._config.enabled, + ) + async def _cached_call(p: str) -> Any: + result = llm_func(p, **kwargs) + if inspect.isawaitable(result): + return await result + return result + + return await _cached_call(prompt) + + +def cache( + similarity: float = 1.0, + ttl: Optional[int] = 3600, + backend: Optional[BaseBackend] = None, + namespace: str = "default", + enabled: bool = True, + key_func: Optional[Callable[..., str]] = None, +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """Decorator for caching LLM function responses. + + Auto-detects whether the decorated function is async or sync and returns + the appropriate wrapper. Both variants share identical cache logic. + + Async functions get a true async wrapper (awaits all backend calls). 
+ Sync functions get a sync wrapper that drives the async backends via a + temporary event loop — not suitable inside a running loop; prefer decorating + async functions when integrating with async frameworks like FastAPI. + + Args: + similarity: Cosine similarity threshold (1.0=exact, 0.9=semantic) + ttl: Time-to-live in seconds (None=forever) + backend: Async storage backend (None=in-memory) + namespace: Cache namespace for isolation + enabled: Whether caching is enabled + key_func: Custom cache key function + + Returns: + Decorated function with caching + + Examples: + >>> @cache(similarity=0.9, ttl=3600) + ... async def ask_llm(prompt: str) -> str: + ... return await call_ollama(prompt) + + >>> @cache() + ... def ask_llm_sync(prompt: str) -> str: + ... return call_ollama_sync(prompt) + """ + _backend = backend or MemoryBackend() + embedding_cache = EmbeddingCache() + + def decorator(func: Callable[P, R]) -> Callable[P, R]: + + if inspect.iscoroutinefunction(func): + # ── Async wrapper ──────────────────────────────────────────────── + @functools.wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if not enabled: + return await func(*args, **kwargs) + + start_time = time.time() + prompt = _extract_prompt(args, kwargs) # type: ignore[arg-type] + normalized = normalize_prompt(prompt) + cache_key = ( + key_func(*args, **kwargs) # type: ignore[arg-type] + if key_func + else hash_prompt(normalized, namespace) + ) + + # 1. Exact match + entry = await _backend.get(cache_key) + if entry is not None: + latency_ms = (time.time() - start_time) * 1000 + _stats_manager.record_hit( + namespace, + latency_saved_ms=latency_ms, + saved_cost=entry.estimate_cost(0.001, 0.002), + ) + return entry.response # type: ignore[return-value] + + # 2. 
Semantic match + if similarity < 1.0: + query_embedding = await embedding_cache.aencode(normalized) + result = await _backend.find_similar( + query_embedding, threshold=similarity, namespace=namespace + ) + if result is not None: + _, matched_entry, _ = result + latency_ms = (time.time() - start_time) * 1000 + _stats_manager.record_hit( + namespace, + latency_saved_ms=latency_ms, + saved_cost=matched_entry.estimate_cost(0.001, 0.002), + ) + return matched_entry.response # type: ignore[return-value] + + # 3. Cache miss — call through + _stats_manager.record_miss(namespace) + + try: + response = await func(*args, **kwargs) + except Exception as e: + raise PromptCacheError(f"LLM function call failed: {e}") from e + + embedding = None + if similarity < 1.0: + embedding = await embedding_cache.aencode(normalized) + + await _backend.set( + cache_key, + CacheEntry( + prompt=normalized, + response=response, + embedding=embedding, + created_at=time.time(), + ttl=ttl, + namespace=namespace, + hit_count=0, + input_tokens=len(normalized) // 4, + output_tokens=len(str(response)) // 4, + ), + ) + return response # type: ignore[return-value] + + return async_wrapper # type: ignore[return-value] + + else: + # ── Sync wrapper (backwards compatibility) ─────────────────────── + # Drives async backends via a dedicated event loop per call. + # Do NOT use inside a running event loop (e.g. FastAPI handlers). + import asyncio + + @functools.wraps(func) + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if not enabled: + return func(*args, **kwargs) + + start_time = time.time() + prompt = _extract_prompt(args, kwargs) # type: ignore[arg-type] + normalized = normalize_prompt(prompt) + cache_key = ( + key_func(*args, **kwargs) # type: ignore[arg-type] + if key_func + else hash_prompt(normalized, namespace) + ) + + loop = asyncio.new_event_loop() + try: + # 1. 
Exact match + entry = loop.run_until_complete(_backend.get(cache_key)) + if entry is not None: + latency_ms = (time.time() - start_time) * 1000 + _stats_manager.record_hit( + namespace, + latency_saved_ms=latency_ms, + saved_cost=entry.estimate_cost(0.001, 0.002), + ) + return entry.response # type: ignore[return-value] + + # 2. Semantic match + if similarity < 1.0: + query_embedding = embedding_cache.encode(normalized) + result = loop.run_until_complete( + _backend.find_similar( + query_embedding, threshold=similarity, namespace=namespace + ) + ) + if result is not None: + _, matched_entry, _ = result + latency_ms = (time.time() - start_time) * 1000 + _stats_manager.record_hit( + namespace, + latency_saved_ms=latency_ms, + saved_cost=matched_entry.estimate_cost(0.001, 0.002), + ) + return matched_entry.response # type: ignore[return-value] + + # 3. Cache miss + _stats_manager.record_miss(namespace) + + try: + response = func(*args, **kwargs) + except Exception as e: + raise PromptCacheError(f"LLM function call failed: {e}") from e + + embedding = None + if similarity < 1.0: + embedding = embedding_cache.encode(normalized) + + loop.run_until_complete( + _backend.set( + cache_key, + CacheEntry( + prompt=normalized, + response=response, + embedding=embedding, + created_at=time.time(), + ttl=ttl, + namespace=namespace, + hit_count=0, + input_tokens=len(normalized) // 4, + output_tokens=len(str(response)) // 4, + ), + ) + ) + return response # type: ignore[return-value] + finally: + loop.close() + + return sync_wrapper # type: ignore[return-value] + + return decorator + + +# Global default backend for utility functions +_default_backend: Optional[BaseBackend] = None + + +def get_default_backend() -> BaseBackend: + """Get default storage backend.""" + global _default_backend + if _default_backend is None: + _default_backend = MemoryBackend() + return _default_backend + + +def set_default_backend(backend: BaseBackend) -> None: + """Set default storage backend.""" + 
global _default_backend + _default_backend = backend + _stats_manager.set_backend(backend) + + +__all__ = [ + "cache", + "CacheContext", + "CachedLLM", + "get_default_backend", + "set_default_backend", +] diff --git a/semantic_llm_cache/exceptions.py b/semantic_llm_cache/exceptions.py new file mode 100644 index 0000000..e3a7287 --- /dev/null +++ b/semantic_llm_cache/exceptions.py @@ -0,0 +1,25 @@ +"""Custom exceptions for prompt-cache.""" + + +class PromptCacheError(Exception): + """Base exception for prompt-cache errors.""" + + pass + + +class CacheBackendError(PromptCacheError): + """Exception raised when backend operations fail.""" + + pass + + +class CacheSerializationError(PromptCacheError): + """Exception raised when serialization/deserialization fails.""" + + pass + + +class CacheNotFoundError(PromptCacheError): + """Exception raised when cache entry is not found.""" + + pass diff --git a/semantic_llm_cache/py.typed b/semantic_llm_cache/py.typed new file mode 100644 index 0000000..a1bf5e9 --- /dev/null +++ b/semantic_llm_cache/py.typed @@ -0,0 +1 @@ +# PEP 561 marker file for type hints diff --git a/semantic_llm_cache/similarity.py b/semantic_llm_cache/similarity.py new file mode 100644 index 0000000..9bafd14 --- /dev/null +++ b/semantic_llm_cache/similarity.py @@ -0,0 +1,283 @@ +"""Embedding generation and similarity matching for llm-semantic-cache.""" + +import asyncio +import hashlib +from functools import lru_cache +from typing import Optional + +import numpy as np + +from semantic_llm_cache.exceptions import PromptCacheError + + +class EmbeddingProvider: + """Base class for embedding providers.""" + + def encode(self, text: str) -> list[float]: + """Generate embedding for text. + + Args: + text: Input text to encode + + Returns: + Embedding vector as list of floats + """ + raise NotImplementedError + + +class DummyEmbeddingProvider(EmbeddingProvider): + """Fallback embedding provider using hash-based vectors. 
+ + Provides consistent embeddings without external dependencies. + Not semantically meaningful but provides consistent cache keys. + """ + + def __init__(self, dim: int = 384) -> None: + """Initialize dummy provider. + + Args: + dim: Embedding dimension (matches MiniLM default) + """ + self._dim = dim + + def encode(self, text: str) -> list[float]: + """Generate hash-based embedding for text. + + Args: + text: Input text to encode + + Returns: + Deterministic embedding vector based on text hash + """ + hash_obj = hashlib.sha256(text.encode()) + hash_bytes = hash_obj.digest() + + values = np.frombuffer(hash_bytes, dtype=np.uint8)[: self._dim].astype( + np.float32 + ) + + if len(values) < self._dim: + values = np.pad(values, (0, self._dim - len(values))) + + norm = np.linalg.norm(values) + if norm > 0: + values = values / norm + + return values.tolist() + + +class SentenceTransformerProvider(EmbeddingProvider): + """Sentence-transformers based embedding provider. + + Uses local models like MiniLM for semantic embeddings. + Inference is CPU/GPU-bound; use aencode() from async contexts. + """ + + def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None: + """Initialize sentence-transformer provider. + + Args: + model_name: Name of sentence-transformer model + """ + try: + from sentence_transformers import SentenceTransformer + except ImportError as e: + raise PromptCacheError( + "sentence-transformers package required for semantic matching. " + "Install with: pip install semantic-llm-cache[semantic]" + ) from e + + self._model = SentenceTransformer(model_name) + self._dim = self._model.get_sentence_embedding_dimension() + + def encode(self, text: str) -> list[float]: + """Generate embedding for text (blocking — use aencode from async code). 
+ + Args: + text: Input text to encode + + Returns: + Normalized embedding vector + """ + embedding = self._model.encode(text, convert_to_numpy=True) + embedding = np.asarray(embedding, dtype=np.float32) + + norm = np.linalg.norm(embedding) + if norm > 0: + embedding = embedding / norm + + return embedding.tolist() + + +class OpenAIEmbeddingProvider(EmbeddingProvider): + """OpenAI API-based embedding provider. + + Uses OpenAI's embedding API for high-quality semantic embeddings. + Network I/O — always use aencode() from async contexts. + """ + + def __init__( + self, api_key: Optional[str] = None, model: str = "text-embedding-3-small" + ) -> None: + """Initialize OpenAI embedding provider. + + Args: + api_key: OpenAI API key (uses OPENAI_API_KEY env var if None) + model: OpenAI embedding model to use + """ + try: + import openai + except ImportError as e: + raise PromptCacheError( + "openai package required for OpenAI embeddings. " + "Install with: pip install semantic-llm-cache[openai]" + ) from e + + self._client = openai.OpenAI(api_key=api_key) + self._model = model + + def encode(self, text: str) -> list[float]: + """Generate embedding for text (blocking — use aencode from async code). + + Args: + text: Input text to encode + + Returns: + OpenAI embedding vector (already normalized) + """ + response = self._client.embeddings.create(input=text, model=self._model) + embedding = response.data[0].embedding + + embedding_arr = np.asarray(embedding, dtype=np.float32) + norm = np.linalg.norm(embedding_arr) + if norm > 0: + embedding_arr = embedding_arr / norm + + return embedding_arr.tolist() + + +def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float: + """Calculate cosine similarity between two vectors. 
+ + Args: + a: First vector + b: Second vector + + Returns: + Similarity score between 0 and 1 + + Raises: + ValueError: If vectors have different dimensions + """ + a_arr = np.asarray(a, dtype=np.float32) + b_arr = np.asarray(b, dtype=np.float32) + + if a_arr.shape != b_arr.shape: + raise ValueError( + f"Vector dimension mismatch: {a_arr.shape} != {b_arr.shape}" + ) + + dot_product = np.dot(a_arr, b_arr) + norm_a = np.linalg.norm(a_arr) + norm_b = np.linalg.norm(b_arr) + + if norm_a == 0 or norm_b == 0: + return 0.0 + + return float(dot_product / (norm_a * norm_b)) + + +def _encode_with_provider(text: str, provider: EmbeddingProvider) -> tuple[float, ...]: + """Helper function for LRU cache encoding. + + Args: + text: Input text + provider: Embedding provider + + Returns: + Embedding as tuple for hashability + """ + return tuple(provider.encode(text)) + + +class EmbeddingCache: + """Cache for embedding generation with LRU eviction. + + Use encode() from sync contexts, aencode() from async contexts. + aencode() offloads blocking inference to a thread pool via asyncio.to_thread. + """ + + def __init__( + self, + provider: Optional[EmbeddingProvider] = None, + cache_size: int = 1024, + ) -> None: + """Initialize embedding cache. + + Args: + provider: Embedding provider (uses DummyEmbeddingProvider if None) + cache_size: Maximum number of embeddings to cache + """ + self._provider = provider or DummyEmbeddingProvider() + self._cache_size = cache_size + self._get_cached = lru_cache(maxsize=cache_size)(_encode_with_provider) + + def encode(self, text: str) -> list[float]: + """Generate embedding with LRU caching (sync, blocking). + + Args: + text: Input text to encode + + Returns: + Embedding vector + """ + return list(self._get_cached(text, self._provider)) + + async def aencode(self, text: str) -> list[float]: + """Generate embedding with LRU caching (async, non-blocking). 
+ + CPU/network-bound work is offloaded to the default thread pool via + asyncio.to_thread, keeping the event loop free. + + Args: + text: Input text to encode + + Returns: + Embedding vector + """ + return await asyncio.to_thread(self.encode, text) + + def clear_cache(self) -> None: + """Clear the embedding LRU cache.""" + self._get_cached.cache_clear() + + +def create_embedding_provider( + provider_type: str = "auto", + model_name: Optional[str] = None, +) -> EmbeddingProvider: + """Create embedding provider based on type. + + Args: + provider_type: Type of provider ("auto", "sentence-transformer", "openai", "dummy") + model_name: Optional model name to use + + Returns: + EmbeddingProvider instance + """ + if provider_type == "auto": + try: + return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2") + except PromptCacheError: + return DummyEmbeddingProvider() + + if provider_type == "sentence-transformer": + return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2") + + if provider_type == "openai": + return OpenAIEmbeddingProvider(model=model_name) + + if provider_type == "dummy": + return DummyEmbeddingProvider() + + raise ValueError(f"Unknown provider type: {provider_type}") diff --git a/semantic_llm_cache/stats.py b/semantic_llm_cache/stats.py new file mode 100644 index 0000000..23138e1 --- /dev/null +++ b/semantic_llm_cache/stats.py @@ -0,0 +1,255 @@ +"""Statistics and analytics for llm-semantic-cache.""" + +from dataclasses import dataclass +from threading import Lock +from typing import Any, Callable, Optional + +from semantic_llm_cache.backends import MemoryBackend +from semantic_llm_cache.backends.base import BaseBackend + + +@dataclass +class CacheStats: + """Statistics for cache performance.""" + + hits: int = 0 + misses: int = 0 + total_saved_ms: float = 0.0 + estimated_savings_usd: float = 0.0 + + @property + def hit_rate(self) -> float: + """Calculate cache hit rate.""" + total = self.hits + self.misses + return self.hits / 
max(total, 1) + + @property + def total_requests(self) -> int: + """Get total number of requests.""" + return self.hits + self.misses + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "hits": self.hits, + "misses": self.misses, + "hit_rate": self.hit_rate, + "total_requests": self.total_requests, + "total_saved_ms": self.total_saved_ms, + "estimated_savings_usd": self.estimated_savings_usd, + } + + def __iadd__(self, other: "CacheStats") -> "CacheStats": + self.hits += other.hits + self.misses += other.misses + self.total_saved_ms += other.total_saved_ms + self.estimated_savings_usd += other.estimated_savings_usd + return self + + +class _StatsManager: + """Manager for global cache statistics. + + Uses threading.Lock for record_hit/record_miss — these are simple counter + increments with no awaits inside the lock, so threading.Lock is safe and + avoids the overhead of asyncio.Lock for hot-path calls. + """ + + def __init__(self) -> None: + """Initialize stats manager.""" + self._stats: dict[str, CacheStats] = {} + self._lock = Lock() + self._default_backend: Optional[BaseBackend] = None + + def get_backend(self) -> BaseBackend: + """Get default backend for cache operations.""" + if self._default_backend is None: + self._default_backend = MemoryBackend() + return self._default_backend + + def set_backend(self, backend: BaseBackend) -> None: + """Set default backend for cache operations.""" + with self._lock: + self._default_backend = backend + + def record_hit( + self, + namespace: str, + latency_saved_ms: float = 0.0, + saved_cost: float = 0.0, + ) -> None: + """Record a cache hit (sync, safe to call from async context).""" + with self._lock: + if namespace not in self._stats: + self._stats[namespace] = CacheStats() + stats = self._stats[namespace] + stats.hits += 1 + stats.total_saved_ms += latency_saved_ms + stats.estimated_savings_usd += saved_cost + + def record_miss(self, namespace: str) -> None: + """Record a cache miss 
(sync, safe to call from async context).""" + with self._lock: + if namespace not in self._stats: + self._stats[namespace] = CacheStats() + self._stats[namespace].misses += 1 + + def get_stats(self, namespace: Optional[str] = None) -> CacheStats: + """Get statistics for namespace or all.""" + with self._lock: + if namespace is not None: + return self._stats.get(namespace, CacheStats()) + + total = CacheStats() + for stats in self._stats.values(): + total += stats + return total + + def clear_stats(self, namespace: Optional[str] = None) -> None: + """Clear statistics for namespace or all.""" + with self._lock: + if namespace is None: + self._stats.clear() + elif namespace in self._stats: + del self._stats[namespace] + + +# Global stats manager instance +_stats_manager = _StatsManager() + + +def get_stats(namespace: Optional[str] = None) -> dict[str, Any]: + """Get cache statistics (sync). + + Args: + namespace: Optional namespace to filter by + + Returns: + Dictionary with cache statistics + """ + return _stats_manager.get_stats(namespace).to_dict() + + +async def clear_cache(namespace: Optional[str] = None) -> int: + """Clear all cached entries (async). + + Args: + namespace: Optional namespace to clear (None = all) + + Returns: + Number of entries cleared + """ + backend = _stats_manager.get_backend() + + if namespace is None: + stats = await backend.get_stats() + size = stats.get("size", 0) + await backend.clear() + _stats_manager.clear_stats() + return size + + entries = await backend.iterate(namespace=namespace) + count = len(entries) + for key, _ in entries: + await backend.delete(key) + _stats_manager.clear_stats(namespace) + return count + + +async def invalidate( + pattern: str, + namespace: Optional[str] = None, +) -> int: + """Invalidate cache entries matching pattern (async). 
+ + Args: + pattern: String pattern to match in prompts + namespace: Optional namespace filter + + Returns: + Number of entries invalidated + """ + backend = _stats_manager.get_backend() + entries = await backend.iterate(namespace=namespace) + count = 0 + pattern_lower = pattern.lower() + + for key, entry in entries: + if pattern_lower in entry.prompt.lower(): + await backend.delete(key) + count += 1 + + return count + + +async def warm_cache( + prompts: list[str], + llm_func: Callable[[str], Any], + namespace: str = "default", +) -> int: + """Pre-populate cache with prompts (async). + + Args: + prompts: List of prompts to cache + llm_func: Async or sync LLM function to call for each prompt + namespace: Cache namespace to use + + Returns: + Number of prompts attempted + """ + import asyncio + import inspect + + from semantic_llm_cache.core import cache + + cached_func = cache(namespace=namespace)(llm_func) + + for prompt in prompts: + try: + result = cached_func(prompt) + if inspect.isawaitable(result): + await result + except Exception: + pass + + return len(prompts) + + +async def export_cache( + namespace: Optional[str] = None, + filepath: Optional[str] = None, +) -> list[dict[str, Any]]: + """Export cache entries for analysis (async). 
+ + Args: + namespace: Optional namespace filter + filepath: Optional file path to save export (JSON) + + Returns: + List of cache entry dictionaries + """ + import json + from datetime import datetime + + backend = _stats_manager.get_backend() + entries = await backend.iterate(namespace=namespace) + + export_data = [] + for key, entry in entries: + export_data.append({ + "key": key, + "prompt": entry.prompt, + "response": str(entry.response)[:1000], + "namespace": entry.namespace, + "hit_count": entry.hit_count, + "created_at": datetime.fromtimestamp(entry.created_at).isoformat(), + "ttl": entry.ttl, + "input_tokens": entry.input_tokens, + "output_tokens": entry.output_tokens, + }) + + if filepath: + with open(filepath, "w", encoding="utf-8") as f: + json.dump(export_data, f, indent=2) + + return export_data diff --git a/semantic_llm_cache/storage.py b/semantic_llm_cache/storage.py new file mode 100644 index 0000000..975180b --- /dev/null +++ b/semantic_llm_cache/storage.py @@ -0,0 +1,111 @@ +"""Storage backend interface for prompt-cache.""" + +from abc import ABC, abstractmethod +from typing import Any, Optional + +from semantic_llm_cache.config import CacheEntry + + +class StorageBackend(ABC): + """Abstract base class for async cache storage backends.""" + + @abstractmethod + async def get(self, key: str) -> Optional[CacheEntry]: + """Retrieve cache entry by key. + + Args: + key: Cache key to retrieve + + Returns: + CacheEntry if found and not expired, None otherwise + + Raises: + CacheBackendError: If backend operation fails + """ + pass + + @abstractmethod + async def set(self, key: str, entry: CacheEntry) -> None: + """Store cache entry. + + Args: + key: Cache key to store under + entry: CacheEntry to store + + Raises: + CacheBackendError: If backend operation fails + """ + pass + + @abstractmethod + async def delete(self, key: str) -> bool: + """Delete cache entry. 
+ + Args: + key: Cache key to delete + + Returns: + True if entry was deleted, False if not found + + Raises: + CacheBackendError: If backend operation fails + """ + pass + + @abstractmethod + async def clear(self) -> None: + """Clear all cache entries. + + Raises: + CacheBackendError: If backend operation fails + """ + pass + + @abstractmethod + async def iterate(self, namespace: Optional[str] = None) -> list[tuple[str, CacheEntry]]: + """Iterate over cache entries, optionally filtered by namespace. + + Args: + namespace: Optional namespace filter + + Returns: + List of (key, entry) tuples + + Raises: + CacheBackendError: If backend operation fails + """ + pass + + @abstractmethod + async def find_similar( + self, + embedding: list[float], + threshold: float, + namespace: Optional[str] = None, + ) -> Optional[tuple[str, CacheEntry, float]]: + """Find semantically similar cached entry. + + Args: + embedding: Query embedding vector + threshold: Minimum similarity score (0-1) + namespace: Optional namespace filter + + Returns: + (key, entry, similarity) tuple if found above threshold, None otherwise + + Raises: + CacheBackendError: If backend operation fails + """ + pass + + @abstractmethod + async def get_stats(self) -> dict[str, Any]: + """Get backend statistics. + + Returns: + Dictionary with stats like size, memory_usage, etc. + + Raises: + CacheBackendError: If backend operation fails + """ + pass diff --git a/semantic_llm_cache/utils/__init__.py b/semantic_llm_cache/utils/__init__.py new file mode 100644 index 0000000..b4b272f --- /dev/null +++ b/semantic_llm_cache/utils/__init__.py @@ -0,0 +1,97 @@ +"""Utility functions for prompt-cache.""" + +import hashlib +import re +from typing import Any + + +def normalize_prompt(prompt: str) -> str: + """Normalize prompt text for consistent caching. 
+ + Args: + prompt: Raw prompt text + + Returns: + Normalized prompt text + """ + # Remove extra whitespace + prompt = " ".join(prompt.split()) + + # Lowercase for better matching (optional - can affect semantics) + # prompt = prompt.lower() + + # Remove common filler words at start + filler_pattern = r"^(please|can you|could you|i need|i want)\s+" + prompt = re.sub(filler_pattern, "", prompt, flags=re.IGNORECASE) + + # Normalize quotes + prompt = prompt.replace('"', "'").replace("`", "'") + + # Remove trailing punctuation + prompt = prompt.rstrip("?!.") + + return prompt.strip() + + +def hash_prompt(prompt: str, namespace: str = "default") -> str: + """Generate cache key from prompt and namespace. + + Args: + prompt: Prompt text + namespace: Cache namespace + + Returns: + Hash-based cache key + """ + combined = f"{namespace}:{prompt}" + return hashlib.sha256(combined.encode()).hexdigest() + + +def estimate_tokens(text: str) -> int: + """Estimate token count for text (rough approximation). + + Args: + text: Input text + + Returns: + Estimated token count + """ + # Rough approximation: ~4 chars per token + return len(text) // 4 + + +def serialize_response(response: Any) -> str: + """Serialize response for storage. + + Args: + response: Response object (string, dict, etc.) + + Returns: + Serialized JSON string + """ + import json + + return json.dumps(response) + + +def deserialize_response(data: str) -> Any: + """Deserialize response from storage. 
+ + Args: + data: Serialized JSON string + + Returns: + Deserialized response object + """ + import json + + return json.loads(data) + + +__all__ = [ + "normalize_prompt", + "hash_prompt", + "estimate_tokens", + "serialize_response", + "deserialize_response", +] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..ec0a551 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for prompt-cache.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..60d8e34 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,50 @@ +"""Pytest configuration and fixtures for prompt-cache tests.""" + +import time + +import pytest + +from semantic_llm_cache.backends import MemoryBackend +from semantic_llm_cache.config import CacheConfig, CacheEntry + + +@pytest.fixture +def backend(): + """Provide a fresh memory backend for each test.""" + return MemoryBackend() + + +@pytest.fixture +def cache_config(): + """Provide default cache configuration.""" + return CacheConfig() + + +@pytest.fixture +def sample_entry(): + """Provide a sample cache entry.""" + return CacheEntry( + prompt="What is Python?", + response="Python is a programming language.", + embedding=[0.1, 0.2, 0.3], + created_at=time.time(), # Use current time + ttl=3600, + namespace="default", + hit_count=0, + ) + + +@pytest.fixture +def mock_llm_func(): + """Provide a mock LLM function.""" + responses = { + "What is Python?": "Python is a programming language.", + "What's Python?": "Python is a programming language.", + "Explain Python": "Python is a high-level programming language.", + "What is Rust?": "Rust is a systems programming language.", + } + + def _func(prompt: str) -> str: + return responses.get(prompt, f"Response to: {prompt}") + + return _func diff --git a/tests/test_backends.py b/tests/test_backends.py new file mode 100644 index 0000000..4b5d94d --- /dev/null +++ b/tests/test_backends.py @@ -0,0 +1,687 @@ +"""Tests for storage backends.""" + 
+import json +import time +from unittest.mock import MagicMock, patch + +import pytest + +from semantic_llm_cache.backends import MemoryBackend +from semantic_llm_cache.backends.sqlite import SQLiteBackend # noqa: F401 +from semantic_llm_cache.config import CacheEntry +from semantic_llm_cache.exceptions import CacheBackendError + + +class TestBaseBackend: + """Tests for BaseBackend abstract class.""" + + def test_cosine_similarity(self): + """Test cosine similarity helper method.""" + backend = MemoryBackend() + + entry1 = CacheEntry( + prompt="test", + response="response", + embedding=[1.0, 0.0, 0.0], + ) + entry2 = CacheEntry( + prompt="test", + response="response", + embedding=[1.0, 0.0, 0.0], + ) + entry3 = CacheEntry( + prompt="test", + response="response", + embedding=[0.0, 1.0, 0.0], + ) + + # Test _find_best_match + candidates = [("key1", entry1), ("key2", entry2), ("key3", entry3)] + + # Query matching entry1 + result = backend._find_best_match(candidates, [1.0, 0.0, 0.0], threshold=0.9) + assert result is not None + key, entry, sim = result + assert sim == pytest.approx(1.0) + + +class TestMemoryBackend: + """Tests for MemoryBackend.""" + + def test_set_and_get(self, backend, sample_entry): + """Test basic set and get operations.""" + backend.set("key1", sample_entry) + retrieved = backend.get("key1") + + assert retrieved is not None + assert retrieved.prompt == sample_entry.prompt + assert retrieved.response == sample_entry.response + + def test_get_nonexistent(self, backend): + """Test getting non-existent key returns None.""" + result = backend.get("nonexistent") + assert result is None + + def test_delete(self, backend, sample_entry): + """Test delete operation.""" + backend.set("key1", sample_entry) + assert backend.get("key1") is not None + + assert backend.delete("key1") is True + assert backend.get("key1") is None + + def test_delete_nonexistent(self, backend): + """Test deleting non-existent key returns False.""" + assert 
backend.delete("nonexistent") is False + + def test_clear(self, backend, sample_entry): + """Test clear operation.""" + backend.set("key1", sample_entry) + backend.set("key2", sample_entry) + backend.clear() + + assert backend.get("key1") is None + assert backend.get("key2") is None + + def test_iterate_all(self, backend): + """Test iterating over all entries.""" + entry1 = CacheEntry(prompt="p1", response="r1", created_at=time.time()) + entry2 = CacheEntry(prompt="p2", response="r2", created_at=time.time()) + + backend.set("key1", entry1) + backend.set("key2", entry2) + + results = backend.iterate() + assert len(results) == 2 + + def test_iterate_with_namespace(self, backend): + """Test iterating with namespace filter.""" + entry1 = CacheEntry( + prompt="p1", response="r1", namespace="ns1", created_at=time.time() + ) + entry2 = CacheEntry( + prompt="p2", response="r2", namespace="ns2", created_at=time.time() + ) + + backend.set("key1", entry1) + backend.set("key2", entry2) + + results = backend.iterate(namespace="ns1") + assert len(results) == 1 + assert results[0][1].namespace == "ns1" + + def test_find_similar(self, backend): + """Test finding semantically similar entries.""" + entry1 = CacheEntry( + prompt="What is Python?", + response="r1", + embedding=[1.0, 0.0, 0.0], + created_at=time.time(), + ) + entry2 = CacheEntry( + prompt="What is Rust?", + response="r2", + embedding=[0.0, 1.0, 0.0], + created_at=time.time(), + ) + + backend.set("key1", entry1) + backend.set("key2", entry2) + + # Find similar to entry1 + result = backend.find_similar([1.0, 0.0, 0.0], threshold=0.9) + assert result is not None + key, entry, sim = result + assert key == "key1" + + def test_find_similar_no_match(self, backend): + """Test find_similar returns None when below threshold.""" + entry = CacheEntry( + prompt="test", + response="response", + embedding=[1.0, 0.0, 0.0], + created_at=time.time(), + ) + + backend.set("key1", entry) + + # Query with orthogonal vector + result = 
backend.find_similar([0.0, 1.0, 0.0], threshold=0.9) + assert result is None + + def test_get_stats(self, backend): + """Test get_stats returns correct info.""" + entry = CacheEntry( + prompt="test", + response="response", + created_at=time.time(), + ) + + backend.set("key1", entry) + stats = backend.get_stats() + + assert stats["size"] == 1 + assert "hits" in stats + assert "misses" in stats + + def test_hit_count_increments(self, backend, sample_entry): + """Test hit count increments on cache hit.""" + backend.set("key1", sample_entry) + + backend.get("key1") # First hit + backend.get("key1") # Second hit + + entry = backend.get("key1") + assert entry.hit_count >= 1 + + def test_lru_eviction(self): + """Test LRU eviction when max_size is reached.""" + backend = MemoryBackend(max_size=2) + + entry1 = CacheEntry(prompt="p1", response="r1", created_at=time.time()) + entry2 = CacheEntry(prompt="p2", response="r2", created_at=time.time()) + entry3 = CacheEntry(prompt="p3", response="r3", created_at=time.time()) + + backend.set("key1", entry1) + backend.set("key2", entry2) + assert backend.get_stats()["size"] == 2 + + backend.set("key3", entry3) + # Should evict oldest (key1) + assert backend.get_stats()["size"] == 2 + + def test_expired_entry_not_returned(self, backend): + """Test expired entries are not returned.""" + entry = CacheEntry( + prompt="test", + response="response", + ttl=1, + created_at=time.time() - 2, # 2 seconds ago with 1s TTL + ) + + backend.set("key1", entry) + result = backend.get("key1") + assert result is None + + +class TestSQLiteBackend: + """Tests for SQLiteBackend.""" + + @pytest.fixture + def sqlite_backend(self, tmp_path): + """Create SQLite backend with temp database.""" + from semantic_llm_cache.backends.sqlite import SQLiteBackend + + db_path = tmp_path / "test_cache.db" + return SQLiteBackend(db_path) + + def test_set_and_get(self, sqlite_backend, sample_entry): + """Test basic set and get operations.""" + sqlite_backend.set("key1", 
sample_entry) + retrieved = sqlite_backend.get("key1") + + assert retrieved is not None + assert retrieved.prompt == sample_entry.prompt + assert retrieved.response == sample_entry.response + + def test_persistence(self, sqlite_backend, sample_entry, tmp_path): + """Test entries persist across backend instances.""" + db_path = tmp_path / "test_persist.db" + + # Create first instance + backend1 = SQLiteBackend(db_path) + backend1.set("key1", sample_entry) + + # Create second instance (simulates restart) + backend2 = SQLiteBackend(db_path) + retrieved = backend2.get("key1") + + assert retrieved is not None + assert retrieved.prompt == sample_entry.prompt + + def test_get_stats(self, sqlite_backend): + """Test get_stats returns correct info.""" + entry = CacheEntry(prompt="test", response="response", created_at=time.time()) + sqlite_backend.set("key1", entry) + + stats = sqlite_backend.get_stats() + assert stats["size"] == 1 + assert "db_path" in stats + + def test_clear(self, sqlite_backend, sample_entry): + """Test clear operation.""" + sqlite_backend.set("key1", sample_entry) + sqlite_backend.clear() + + assert sqlite_backend.get("key1") is None + + def test_close_and_reopen(self, sqlite_backend, sample_entry): + """Test closing and reopening connection.""" + sqlite_backend.set("key1", sample_entry) + sqlite_backend.close() + + # Should be able to use after close (reopens connection) + retrieved = sqlite_backend.get("key1") + assert retrieved is not None + + +class TestRedisBackend: + """Tests for RedisBackend.""" + + @pytest.fixture + def mock_redis(self): + """Create mock Redis client.""" + with patch("semantic_llm_cache.backends.redis.redis_lib") as mock: + mock_client = MagicMock() + mock.from_url.return_value = mock_client + mock_client.ping.return_value = True + mock_client.get.return_value = None + mock_client.keys.return_value = [] + mock_client.delete.return_value = 1 + yield mock_client + + @pytest.fixture + def redis_backend(self, mock_redis): + 
"""Create Redis backend with mocked client.""" + from semantic_llm_cache.backends.redis import RedisBackend + + backend = RedisBackend(url="redis://localhost:6379/0") + backend._redis = mock_redis + return backend + + def test_set_and_get(self, redis_backend, mock_redis, sample_entry): + """Test basic set and get operations.""" + # Mock get to return stored data + mock_redis.get.return_value = json.dumps({ + "prompt": sample_entry.prompt, + "response": sample_entry.response, + "embedding": sample_entry.embedding, + "created_at": sample_entry.created_at, + "ttl": sample_entry.ttl, + "namespace": sample_entry.namespace, + "hit_count": 0, + "input_tokens": sample_entry.input_tokens, + "output_tokens": sample_entry.output_tokens, + }).encode() + + redis_backend.set("key1", sample_entry) + retrieved = redis_backend.get("key1") + + assert retrieved is not None + assert retrieved.prompt == sample_entry.prompt + # set is called twice: once for initial set, once to update hit_count + assert mock_redis.set.call_count == 2 + + def test_get_nonexistent(self, redis_backend, mock_redis): + """Test getting non-existent key returns None.""" + mock_redis.get.return_value = None + result = redis_backend.get("nonexistent") + assert result is None + + def test_delete(self, redis_backend, mock_redis, sample_entry): + """Test delete operation.""" + mock_redis.delete.return_value = 1 + result = redis_backend.delete("key1") + assert result is True + mock_redis.delete.assert_called_once() + + def test_delete_nonexistent(self, redis_backend, mock_redis): + """Test deleting non-existent key returns False.""" + mock_redis.delete.return_value = 0 + result = redis_backend.delete("nonexistent") + assert result is False + + def test_clear(self, redis_backend, mock_redis): + """Test clear operation.""" + mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"] + redis_backend.clear() + mock_redis.delete.assert_called_once() + + def test_clear_empty(self, 
redis_backend, mock_redis): + """Test clear with no entries.""" + mock_redis.keys.return_value = [] + redis_backend.clear() + mock_redis.delete.assert_not_called() + + def test_iterate_all(self, redis_backend, mock_redis, sample_entry): + """Test iterating over all entries.""" + entry_dict = { + "prompt": sample_entry.prompt, + "response": sample_entry.response, + "embedding": sample_entry.embedding, + "created_at": sample_entry.created_at, + "ttl": sample_entry.ttl, + "namespace": sample_entry.namespace, + "hit_count": 0, + "input_tokens": sample_entry.input_tokens, + "output_tokens": sample_entry.output_tokens, + } + + mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"] + mock_redis.get.return_value = json.dumps(entry_dict).encode() + + results = redis_backend.iterate() + assert len(results) == 2 + + def test_iterate_with_namespace(self, redis_backend, mock_redis, sample_entry): + """Test iterating with namespace filter.""" + entry1_dict = { + "prompt": "p1", + "response": "r1", + "embedding": None, + "created_at": time.time(), + "ttl": None, + "namespace": "ns1", + "hit_count": 0, + "input_tokens": 0, + "output_tokens": 0, + } + entry2_dict = entry1_dict.copy() + entry2_dict["namespace"] = "ns2" + + call_count = [0] + + def mock_get(key): + call_count[0] += 1 + if call_count[0] == 1: + return json.dumps(entry1_dict).encode() + return json.dumps(entry2_dict).encode() + + mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"] + mock_redis.get.side_effect = mock_get + + results = redis_backend.iterate(namespace="ns1") + assert len(results) == 1 + assert results[0][1].namespace == "ns1" + + def test_find_similar(self, redis_backend, mock_redis): + """Test finding semantically similar entries.""" + entry_dict = { + "prompt": "What is Python?", + "response": "r1", + "embedding": [1.0, 0.0, 0.0], + "created_at": time.time(), + "ttl": None, + "namespace": "default", + "hit_count": 0, + "input_tokens": 
0, + "output_tokens": 0, + } + + mock_redis.keys.return_value = [b"semantic_llm_cache:key1"] + mock_redis.get.return_value = json.dumps(entry_dict).encode() + + result = redis_backend.find_similar([1.0, 0.0, 0.0], threshold=0.9) + assert result is not None + key, entry, sim = result + assert key == "key1" + + def test_find_similar_no_match(self, redis_backend, mock_redis): + """Test find_similar returns None when below threshold.""" + entry_dict = { + "prompt": "test", + "response": "response", + "embedding": [1.0, 0.0, 0.0], + "created_at": time.time(), + "ttl": None, + "namespace": "default", + "hit_count": 0, + "input_tokens": 0, + "output_tokens": 0, + } + + mock_redis.keys.return_value = [b"semantic_llm_cache:key1"] + mock_redis.get.return_value = json.dumps(entry_dict).encode() + + # Query with orthogonal vector + result = redis_backend.find_similar([0.0, 1.0, 0.0], threshold=0.9) + assert result is None + + def test_get_stats(self, redis_backend, mock_redis): + """Test get_stats returns correct info.""" + mock_redis.keys.return_value = [b"semantic_llm_cache:key1", b"semantic_llm_cache:key2"] + stats = redis_backend.get_stats() + + assert "prefix" in stats + assert stats["size"] == 2 + assert stats["prefix"] == "semantic_llm_cache:" + + def test_get_stats_error_handling(self, redis_backend, mock_redis): + """Test get_stats handles Redis errors gracefully.""" + mock_redis.keys.side_effect = Exception("Connection lost") + stats = redis_backend.get_stats() + + assert "error" in stats + assert stats["size"] == 0 + + def test_make_key(self, redis_backend): + """Test key prefixing.""" + result = redis_backend._make_key("test_key") + assert result == "semantic_llm_cache:test_key" + + def test_entry_to_dict(self, redis_backend, sample_entry): + """Test converting entry to dictionary.""" + result = redis_backend._entry_to_dict(sample_entry) + assert result["prompt"] == sample_entry.prompt + assert result["response"] == sample_entry.response + assert 
result["embedding"] == sample_entry.embedding + + def test_dict_to_entry(self, redis_backend): + """Test converting dictionary to entry.""" + data = { + "prompt": "test", + "response": "response", + "embedding": [1.0, 0.0], + "created_at": time.time(), + "ttl": 100, + "namespace": "test_ns", + "hit_count": 5, + "input_tokens": 100, + "output_tokens": 50, + } + + entry = redis_backend._dict_to_entry(data) + assert entry.prompt == "test" + assert entry.namespace == "test_ns" + assert entry.hit_count == 5 + + def test_dict_to_entry_defaults(self, redis_backend): + """Test dict_to_entry uses defaults for missing fields.""" + data = { + "prompt": "test", + "response": "response", + "created_at": time.time(), + } + + entry = redis_backend._dict_to_entry(data) + assert entry.embedding is None + assert entry.ttl is None + assert entry.namespace == "default" + assert entry.hit_count == 0 + assert entry.input_tokens == 0 + assert entry.output_tokens == 0 + + def test_connection_failure(self): + """Test connection failure raises CacheBackendError.""" + from semantic_llm_cache.backends.redis import RedisBackend + from semantic_llm_cache.exceptions import CacheBackendError + + # Need to patch both the import and the from_url call + with patch("semantic_llm_cache.backends.redis.redis_lib") as mock_redis: + mock_client = MagicMock() + mock_client.ping.side_effect = Exception("Connection refused") + mock_redis.from_url.return_value = mock_client + + with pytest.raises(CacheBackendError, match="Failed to connect"): + RedisBackend(url="redis://localhost:9999/0") + + def test_set_with_ttl(self, redis_backend, mock_redis, sample_entry): + """Test setting entry with TTL.""" + sample_entry.ttl = 3600 + redis_backend.set("key1", sample_entry) + + call_args = mock_redis.set.call_args + assert call_args[1]["ex"] == 3600 + + def test_get_expired_entry(self, redis_backend, mock_redis): + """Test expired entry is not returned.""" + expired_dict = { + "prompt": "test", + "response": 
"response", + "embedding": None, + "created_at": time.time() - 1000, + "ttl": 100, # 100 seconds TTL, created 1000 seconds ago + "namespace": "default", + "hit_count": 0, + "input_tokens": 0, + "output_tokens": 0, + } + + mock_redis.get.return_value = json.dumps(expired_dict).encode() + mock_redis.delete.return_value = 1 + + result = redis_backend.get("expired_key") + assert result is None + mock_redis.delete.assert_called_once() + + def test_close(self, redis_backend, mock_redis): + """Test closing Redis connection.""" + redis_backend.close() + mock_redis.close.assert_called_once() + + def test_close_error_handling(self, redis_backend, mock_redis): + """Test close handles errors gracefully.""" + mock_redis.close.side_effect = Exception("Close error") + # Should not raise + redis_backend.close() + + def test_iterate_with_expired_entries(self, redis_backend, mock_redis): + """Test iterate filters out expired entries.""" + expired_dict = { + "prompt": "expired", + "response": "response", + "embedding": None, + "created_at": time.time() - 1000, + "ttl": 100, + "namespace": "default", + "hit_count": 0, + "input_tokens": 0, + "output_tokens": 0, + } + valid_dict = { + "prompt": "valid", + "response": "response", + "embedding": None, + "created_at": time.time(), + "ttl": None, + "namespace": "default", + "hit_count": 0, + "input_tokens": 0, + "output_tokens": 0, + } + + call_count = [0] + + def mock_get(key): + call_count[0] += 1 + if call_count[0] == 1: + return json.dumps(expired_dict).encode() + return json.dumps(valid_dict).encode() + + mock_redis.keys.return_value = [b"semantic_llm_cache:expired", b"semantic_llm_cache:valid"] + mock_redis.get.side_effect = mock_get + + results = redis_backend.iterate() + # Only valid entry should be returned + assert len(results) == 1 + assert results[0][1].prompt == "valid" + + def test_set_error_handling(self, redis_backend, mock_redis, sample_entry): + """Test set handles Redis errors.""" + from semantic_llm_cache.exceptions 
import CacheBackendError + + mock_redis.set.side_effect = Exception("Redis error") + + with pytest.raises(CacheBackendError, match="Failed to set"): + redis_backend.set("key1", sample_entry) + + def test_delete_error_handling(self, redis_backend, mock_redis): + """Test delete handles Redis errors.""" + from semantic_llm_cache.exceptions import CacheBackendError + + mock_redis.delete.side_effect = Exception("Redis error") + + with pytest.raises(CacheBackendError, match="Failed to delete"): + redis_backend.delete("key1") + + def test_iterate_error_handling(self, redis_backend, mock_redis): + """Test iterate handles Redis errors.""" + from semantic_llm_cache.exceptions import CacheBackendError + + mock_redis.keys.side_effect = Exception("Redis error") + + with pytest.raises(CacheBackendError, match="Failed to iterate"): + redis_backend.iterate() + + def test_get_json_error(self, redis_backend, mock_redis): + """Test get handles invalid JSON.""" + import json + mock_redis.get.return_value = b"invalid json" + + # The JSON decode error should be wrapped in CacheBackendError + with pytest.raises((CacheBackendError, json.JSONDecodeError)): + redis_backend.get("key1") + + def test_import_error_without_package(self): + """Test ImportError when redis package not installed.""" + # This test validates the import guard in redis.py + from semantic_llm_cache.backends import redis as redis_module + + # Check that RedisBackend is defined + assert hasattr(redis_module, "RedisBackend") + + +class TestCacheEntry: + """Tests for CacheEntry dataclass.""" + + def test_is_expired_with_none_ttl(self): + """Test entry with None TTL never expires.""" + entry = CacheEntry( + prompt="test", + response="response", + ttl=None, + created_at=time.time() - 10000, + ) + assert not entry.is_expired(time.time()) + + def test_is_expired_with_ttl(self): + """Test entry with TTL expires correctly.""" + entry = CacheEntry( + prompt="test", + response="response", + ttl=10, + created_at=time.time() - 15, + ) 
+ assert entry.is_expired(time.time()) + + def test_is_expired_not_yet(self): + """Test entry not yet expired.""" + entry = CacheEntry( + prompt="test", + response="response", + ttl=10, + created_at=time.time() - 5, + ) + assert not entry.is_expired(time.time()) + + def test_estimate_cost(self): + """Test cost estimation.""" + entry = CacheEntry( + prompt="test", + response="response", + input_tokens=1000, + output_tokens=500, + ) + cost = entry.estimate_cost(0.001, 0.002) + # 1000/1000 * 0.001 + 500/1000 * 0.002 = 0.001 + 0.001 = 0.002 + assert cost == pytest.approx(0.002) diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..d920cee --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,606 @@ +"""Tests for core cache decorator and API.""" + +import time +from unittest.mock import MagicMock + +import pytest + +from semantic_llm_cache import CacheContext, CachedLLM, cache, get_default_backend, set_default_backend +from semantic_llm_cache.backends import MemoryBackend +from semantic_llm_cache.config import CacheConfig, CacheEntry +from semantic_llm_cache.exceptions import PromptCacheError + + +class TestCacheDecorator: + """Tests for @cache decorator.""" + + def test_exact_match_cache_hit(self, mock_llm_func): + """Test exact match caching returns cached result.""" + call_count = {"count": 0} + + @cache() + def cached_func(prompt: str) -> str: + call_count["count"] += 1 + return mock_llm_func(prompt) + + # First call - cache miss + result1 = cached_func("What is Python?") + assert result1 == "Python is a programming language." + assert call_count["count"] == 1 + + # Second call - cache hit + result2 = cached_func("What is Python?") + assert result2 == "Python is a programming language." 
+ assert call_count["count"] == 1 # No additional call + + def test_cache_miss_different_prompt(self, mock_llm_func): + """Test different prompts result in cache misses.""" + @cache() + def cached_func(prompt: str) -> str: + return mock_llm_func(prompt) + + result1 = cached_func("What is Python?") + result2 = cached_func("What is Rust?") + + assert result1 == "Python is a programming language." + assert result2 == "Rust is a systems programming language." + + def test_cache_disabled(self, mock_llm_func): + """Test caching can be disabled.""" + call_count = {"count": 0} + + @cache(enabled=False) + def cached_func(prompt: str) -> str: + call_count["count"] += 1 + return mock_llm_func(prompt) + + cached_func("What is Python?") + cached_func("What is Python?") + + assert call_count["count"] == 2 # Both calls hit the function + + def test_custom_namespace(self, mock_llm_func): + """Test custom namespace isolates cache.""" + @cache(namespace="test") + def cached_func(prompt: str) -> str: + return mock_llm_func(prompt) + + result = cached_func("What is Python?") + assert result == "Python is a programming language." + + def test_ttl_expiration(self, mock_llm_func): + """Test TTL expiration works.""" + @cache(ttl=1) # 1 second TTL + def cached_func(prompt: str) -> str: + return mock_llm_func(prompt) + + # First call + result1 = cached_func("What is Python?") + assert result1 == "Python is a programming language." + + # Immediate second call - should hit cache + result2 = cached_func("What is Python?") + assert result2 == "Python is a programming language." + + # Wait for expiration + time.sleep(1.1) + + # Third call - should miss due to TTL + # Note: This test may be flaky in slow CI environments + result3 = cached_func("What is Python?") + assert result3 == "Python is a programming language." 
+ + def test_cache_with_exception(self): + """Test cache handles exceptions properly.""" + @cache() + def failing_func(prompt: str) -> str: + raise ValueError("LLM API error") + + with pytest.raises(PromptCacheError): + failing_func("test prompt") + + def test_semantic_similarity_match(self, mock_llm_func): + """Test semantic similarity matching.""" + call_count = {"count": 0} + + @cache(similarity=0.85) + def cached_func(prompt: str) -> str: + call_count["count"] += 1 + return mock_llm_func(prompt) + + # First call + cached_func("What is Python?") + assert call_count["count"] == 1 + + # Similar prompts may hit cache depending on embedding + # Note: With dummy embeddings, exact string matching determines similarity + cached_func("What is Python?") # Exact match + assert call_count["count"] == 1 + + +class TestCacheContext: + """Tests for CacheContext manager.""" + + def test_context_manager(self): + """Test CacheContext works as context manager.""" + with CacheContext(similarity=0.9, ttl=1800) as ctx: + assert ctx.config.similarity_threshold == 0.9 + assert ctx.config.ttl == 1800 + + def test_context_stats(self): + """Test context tracks stats.""" + with CacheContext() as ctx: + stats = ctx.stats + assert "hits" in stats + assert "misses" in stats + + +class TestCachedLLM: + """Tests for CachedLLM wrapper class.""" + + def test_init(self): + """Test CachedLLM initialization.""" + llm = CachedLLM(provider="openai", model="gpt-4") + assert llm._provider == "openai" + assert llm._model == "gpt-4" + + def test_chat_with_llm_func(self): + """Test chat method with custom LLM function.""" + llm = CachedLLM() + + def mock_llm(prompt: str) -> str: + return f"Response to: {prompt}" + + result = llm.chat("Hello", llm_func=mock_llm) + assert result == "Response to: Hello" + + def test_chat_caches_responses(self, mock_llm_func): + """Test chat caches responses.""" + llm = CachedLLM() + call_count = {"count": 0} + + def counting_llm(prompt: str) -> str: + call_count["count"] += 
            # (tail of a cached-chat test whose definition starts above this chunk)
            return mock_llm_func(prompt)

        llm.chat("What is Python?", llm_func=counting_llm)
        llm.chat("What is Python?", llm_func=counting_llm)

        # Should cache (depends on embedding, may not with dummy)
        assert call_count["count"] >= 1


class TestBackendManagement:
    """Tests for backend management functions."""

    def test_get_default_backend(self):
        """Test get_default_backend returns a backend."""
        backend = get_default_backend()
        assert backend is not None
        assert isinstance(backend, MemoryBackend)

    def test_set_default_backend(self):
        """Test set_default_backend changes default."""
        custom_backend = MemoryBackend(max_size=10)
        set_default_backend(custom_backend)

        backend = get_default_backend()
        assert backend is custom_backend


class TestCacheEntry:
    """Tests for CacheEntry class."""

    def test_entry_creation(self):
        """Test CacheEntry creation."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            created_at=time.time(),
        )
        assert entry.prompt == "test"
        assert entry.response == "response"

    def test_is_expired_no_ttl(self):
        """Test entry without TTL never expires."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            ttl=None,
            created_at=time.time() - 1000,  # old entry, but no TTL set
        )
        assert not entry.is_expired(time.time())

    def test_is_expired_with_ttl(self):
        """Test entry with TTL expires correctly."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            ttl=1,  # 1 second
            created_at=time.time() - 2,  # 2 seconds ago
        )
        assert entry.is_expired(time.time())

    def test_estimate_cost(self):
        """Test cost estimation."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            input_tokens=100,
            output_tokens=50,
        )
        cost = entry.estimate_cost(0.001, 0.002)
        # 100/1000 * 0.001 + 50/1000 * 0.002 = 0.0001 + 0.0001 = 0.0002
        assert abs(cost - 0.0002) < 1e-6


class TestCacheConfig:
    """Tests for CacheConfig class."""

    def test_default_config(self):
        """Test default configuration values."""
        config = CacheConfig()
        assert config.similarity_threshold == 1.0
        assert config.ttl == 3600
        assert config.namespace == "default"
        assert config.enabled is True

    def test_custom_config(self):
        """Test custom configuration values."""
        config = CacheConfig(
            similarity_threshold=0.9,
            ttl=7200,
            namespace="custom",
            enabled=False,
        )
        assert config.similarity_threshold == 0.9
        assert config.ttl == 7200
        assert config.namespace == "custom"
        assert config.enabled is False

    def test_invalid_similarity(self):
        """Test invalid similarity raises error."""
        with pytest.raises(ValueError, match="similarity_threshold"):
            CacheConfig(similarity_threshold=1.5)

    def test_invalid_ttl(self):
        """Test invalid TTL raises error."""
        with pytest.raises(ValueError, match="ttl"):
            CacheConfig(ttl=-1)

    def test_invalid_max_size(self):
        """Test invalid max_size raises error."""
        with pytest.raises(ValueError, match="max_cache_size"):
            CacheConfig(max_cache_size=0)


class TestCacheDecoratorEdgeCases:
    """Tests for edge cases in cache decorator."""

    def test_cache_with_kwargs_only(self, mock_llm_func):
        """Test caching when function is called with kwargs only."""
        @cache()
        def cached_func(prompt: str) -> str:
            return mock_llm_func(prompt)

        result = cached_func(prompt="What is Python?")
        assert result == "Python is a programming language."
    def test_cache_with_multiple_args(self):
        """Test caching with multiple arguments."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str, temperature: float = 0.7) -> str:
            call_count["count"] += 1
            return f"Response to: {prompt} at {temperature}"

        cached_func("test", 0.5)
        cached_func("test", 0.5)
        # Same prompt hits cache even with different temperature
        # (cache key is based on first arg)
        cached_func("test", 0.9)
        cached_func("different", 0.5)

        # First call + different prompt = 2 calls
        assert call_count["count"] == 2

    def test_cache_with_custom_key_func(self):
        """Test custom key function."""
        call_count = {"count": 0}

        def custom_key(prompt: str, temperature: float = 0.7) -> str:
            # Key includes temperature, unlike the default first-arg key.
            return f"{prompt}:{temperature}"

        @cache(key_func=custom_key)
        def cached_func(prompt: str, temperature: float = 0.7) -> str:
            call_count["count"] += 1
            return f"Response to: {prompt} at {temperature}"

        cached_func("test", 0.7)
        cached_func("test", 0.7)
        assert call_count["count"] == 1

    def test_semantic_match_threshold_edge(self, mock_llm_func):
        """Test semantic matching at threshold boundaries."""
        call_count = {"count": 0}

        @cache(similarity=0.5)  # Lower threshold
        def cached_func(prompt: str) -> str:
            call_count["count"] += 1
            return mock_llm_func(prompt)

        # First call
        cached_func("What is Python?")
        # Similar query may hit with lower threshold
        cached_func("What is Python?")  # Exact match always hits
        assert call_count["count"] == 1

    def test_cache_with_none_response(self):
        """Test caching None responses."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str) -> None:
            call_count["count"] += 1
            return None

        result1 = cached_func("test")
        result2 = cached_func("test")

        assert result1 is None
        assert result2 is None
        assert call_count["count"] == 1  # Should cache

    def test_cache_with_empty_string_response(self):
        """Test caching empty string responses."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str) -> str:
            call_count["count"] += 1
            return ""

        result1 = cached_func("test")
        result2 = cached_func("test")

        assert result1 == ""
        assert result2 == ""
        assert call_count["count"] == 1  # Should cache

    def test_cache_with_dict_response(self):
        """Test caching dict responses."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str) -> dict:
            call_count["count"] += 1
            return {"key": "value", "number": 42}

        result1 = cached_func("test")
        cached_func("test")  # Second call to verify caching

        assert result1 == {"key": "value", "number": 42}
        assert call_count["count"] == 1

    def test_cache_with_list_response(self):
        """Test caching list responses."""
        call_count = {"count": 0}

        @cache()
        def cached_func(prompt: str) -> list:
            call_count["count"] += 1
            return [1, 2, 3, 4, 5]

        result1 = cached_func("test")
        cached_func("test")  # Second call to verify caching

        assert result1 == [1, 2, 3, 4, 5]
        assert call_count["count"] == 1


class TestCacheDecoratorErrorPaths:
    """Tests for error handling in cache decorator."""

    def test_backend_set_raises_error(self):
        """Test that backend.set errors propagate."""
        from semantic_llm_cache.exceptions import CacheBackendError

        backend = MagicMock()
        # Backend set wraps exceptions in CacheBackendError
        backend.set.side_effect = CacheBackendError("Storage error")
        backend.get.return_value = None  # No cached value

        @cache(backend=backend)
        def cached_func(prompt: str) -> str:
            return "response"

        # The CacheBackendError from backend.set should propagate
        with pytest.raises(CacheBackendError, match="Storage error"):
            cached_func("test")

    def test_backend_get_raises_error(self):
        """Test that backend.get errors propagate."""
        from semantic_llm_cache.exceptions import CacheBackendError

        backend = MagicMock()
        # Backend get wraps exceptions in CacheBackendError
        backend.get.side_effect = CacheBackendError("Get error")

        @cache(backend=backend)
        def cached_func(prompt: str) -> str:
            return "response"

        # The CacheBackendError from backend.get should propagate
        with pytest.raises(CacheBackendError, match="Get error"):
            cached_func("test")

    def test_llm_error_still_wrapped(self):
        """Test that LLM errors are still wrapped in PromptCacheError."""
        from semantic_llm_cache.exceptions import PromptCacheError

        @cache()
        def failing_func(prompt: str) -> str:
            raise ValueError("LLM API error")

        with pytest.raises(PromptCacheError):
            failing_func("test")


class TestCacheContextAdvanced:
    """Advanced tests for CacheContext."""

    def test_context_with_zero_similarity(self):
        """Test context with zero similarity (accept all)."""
        with CacheContext(similarity=0.0) as ctx:
            assert ctx.config.similarity_threshold == 0.0

    def test_context_with_infinite_ttl(self):
        """Test context with infinite TTL (None)."""
        with CacheContext(ttl=None) as ctx:
            assert ctx.config.ttl is None

    def test_context_disabled(self):
        """Test disabled context."""
        with CacheContext(enabled=False) as ctx:
            assert ctx.config.enabled is False


class TestCachedLLMAdvanced:
    """Advanced tests for CachedLLM."""

    def test_chat_with_kwargs(self):
        """Test chat with additional kwargs passed to LLM."""
        llm = CachedLLM()

        def mock_llm(prompt: str, temperature: float = 0.7, max_tokens: int = 100) -> str:
            return f"Response to: {prompt} (temp={temperature}, tokens={max_tokens})"

        result = llm.chat("Hello", llm_func=mock_llm, temperature=0.5, max_tokens=200)
        assert "temp=0.5" in result
        assert "tokens=200" in result

    def test_cached_llm_with_custom_backend(self):
        """Test CachedLLM with custom backend."""
        custom_backend = MemoryBackend(max_size=5)
        llm = CachedLLM(backend=custom_backend)

        # NOTE(review): reaches into the private _backend attribute.
        assert llm._backend is custom_backend

    def test_cached_llm_different_namespaces(self):
        """Test CachedLLM with different namespaces."""
        llm1 = CachedLLM(namespace="ns1")
        llm2 = CachedLLM(namespace="ns2")

        assert llm1._config.namespace == "ns1"
        assert llm2._config.namespace == "ns2"


class TestBackendManagementAdvanced:
    """Advanced tests for backend management."""

    def test_multiple_default_backend_changes(self):
        """Test changing default backend multiple times."""
        backend1 = MemoryBackend(max_size=10)
        backend2 = MemoryBackend(max_size=20)
        backend3 = MemoryBackend(max_size=30)

        set_default_backend(backend1)
        assert get_default_backend() is backend1

        set_default_backend(backend2)
        assert get_default_backend() is backend2

        set_default_backend(backend3)
        assert get_default_backend() is backend3

    def test_backend_persists_stats(self):
        """Test backend stats persist across get_default_backend calls."""
        backend = get_default_backend()

        # Create an entry
        from semantic_llm_cache.config import CacheEntry
        entry = CacheEntry(
            prompt="test",
            response="response",
            created_at=time.time()
        )
        backend.set("key1", entry)

        # Get stats
        stats = backend.get_stats()
        assert stats["size"] == 1


class TestCacheEntryEdgeCases:
    """Edge case tests for CacheEntry."""

    def test_entry_with_zero_tokens(self):
        """Test entry with zero token counts."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            input_tokens=0,
            output_tokens=0,
        )
        cost = entry.estimate_cost(0.001, 0.002)
        assert cost == 0.0

    def test_entry_with_large_token_count(self):
        """Test entry with large token counts."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            input_tokens=100000,
            output_tokens=50000,
        )
        cost = entry.estimate_cost(0.001, 0.002)
        assert cost > 0

    def test_entry_with_negative_ttl(self):
        """Test entry creation handles negative TTL (becomes expired)."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            ttl=-1,
            created_at=time.time(),
        )
        # Negative TTL means immediately expired
        assert entry.is_expired(time.time())

    def test_entry_hit_count_initialization(self):
        """Test entry initializes with zero hit count."""
        entry = CacheEntry(
            prompt="test",
            response="response",
            created_at=time.time(),
        )
        assert entry.hit_count == 0


class TestCacheConfigEdgeCases:
    """Edge case tests for CacheConfig."""

    def test_config_boundary_similarity(self):
        """Test similarity at valid boundaries."""
        config1 = CacheConfig(similarity_threshold=0.0)
        assert config1.similarity_threshold == 0.0

        config2 = CacheConfig(similarity_threshold=1.0)
        assert config2.similarity_threshold == 1.0

    def test_config_zero_ttl(self):
        """Test zero TTL is rejected (validation requires positive)."""
        # The validation in CacheConfig rejects ttl <= 0
        with pytest.raises(ValueError, match="ttl"):
            CacheConfig(ttl=0)

    def test_config_very_large_ttl(self):
        """Test very large TTL."""
        config = CacheConfig(ttl=86400 * 365)  # 1 year
        assert config.ttl == 86400 * 365

    def test_config_with_special_namespace(self):
        """Test namespace with special characters."""
        config = CacheConfig(namespace="test-ns_123.v1")
        assert config.namespace == "test-ns_123.v1"

# ============================================================
# tests/test_integration.py (new file in this diff)
# ============================================================
"""Integration tests for prompt-cache."""

import time

import pytest

from semantic_llm_cache import cache, clear_cache, get_stats, invalidate
from semantic_llm_cache.backends import MemoryBackend


class TestEndToEnd:
    """End-to-end integration tests."""

    def test_full_cache_workflow(self):
        """Test complete cache workflow from hit to miss."""
        backend = MemoryBackend()
        call_count = {"count": 0}

        @cache(backend=backend)
        def llm_function(prompt: str) -> str:
            call_count["count"] += 1
            return f"Response to: {prompt}"

        # First call - miss
        result1 = llm_function("What is Python?")
        assert result1 == "Response to: What is Python?"
        assert call_count["count"] == 1

        # Second call - hit
        result2 = llm_function("What is Python?")
        assert result2 == "Response to: What is Python?"
        assert call_count["count"] == 1

        # Different prompt - miss
        result3 = llm_function("What is Rust?")
        assert result3 == "Response to: What is Rust?"
        assert call_count["count"] == 2

    def test_stats_integration(self):
        """Test statistics tracking."""
        backend = MemoryBackend()

        @cache(backend=backend, namespace="test")
        def llm_function(prompt: str) -> str:
            return f"Response to: {prompt}"

        # Generate some activity
        llm_function("prompt 1")
        llm_function("prompt 1")  # Hit
        llm_function("prompt 2")

        stats = get_stats(namespace="test")
        assert stats["total_requests"] >= 2

    def test_clear_cache_integration(self):
        """Test clearing cache affects function behavior."""
        backend = MemoryBackend()

        @cache(backend=backend)
        def llm_function(prompt: str) -> str:
            return f"Response to: {prompt}"

        llm_function("test prompt")

        # Clear cache
        cleared = clear_cache()
        assert cleared >= 0

        # Function should still work
        result = llm_function("test prompt")
        assert result == "Response to: test prompt"

    def test_invalidate_integration(self):
        """Test invalidating cache entries."""
        backend = MemoryBackend()

        @cache(backend=backend)
        def llm_function(prompt: str) -> str:
            return f"Response to: {prompt}"

        llm_function("Python programming")
        llm_function("Rust programming")

        # Invalidate Python entries
        count = invalidate("Python")
        assert count >= 0

    def test_multiple_namespaces(self):
        """Test cache isolation across namespaces."""
        backend = MemoryBackend()

        @cache(backend=backend, namespace="app1")
        def app1_llm(prompt: str) -> str:
            return f"App1: {prompt}"

        @cache(backend=backend, namespace="app2")
        def app2_llm(prompt: str) -> str:
            return f"App2: {prompt}"

        result1 = app1_llm("test")
        result2 = app2_llm("test")
        assert result1 == "App1: test"
        assert result2 == "App2: test"

    def test_ttl_expiration_integration(self):
        """Test TTL expiration in real workflow."""
        backend = MemoryBackend()

        @cache(backend=backend, ttl=1)  # 1 second TTL
        def llm_function(prompt: str) -> str:
            return f"Response to: {prompt}"

        llm_function("test prompt")

        # Immediate second call - hit
        llm_function("test prompt")

        # Wait for expiration
        time.sleep(1.5)

        # Should miss (cached entry expired)
        llm_function("test prompt")


class TestComplexScenarios:
    """Tests for complex real-world scenarios."""

    def test_high_volume_caching(self):
        """Test cache behavior with many entries."""
        backend = MemoryBackend(max_size=100)
        call_count = {"count": 0}

        @cache(backend=backend)
        def llm_function(prompt: str) -> str:
            call_count["count"] += 1
            return f"Response {call_count['count']}"

        # Add many entries (more than max_size to force eviction)
        for i in range(150):
            llm_function(f"prompt {i}")

        # Some entries should have been evicted
        stats = backend.get_stats()
        assert stats["size"] <= 100

    def test_concurrent_like_access(self):
        """Test multiple calls to same cached entry."""
        backend = MemoryBackend()

        @cache(backend=backend)
        def llm_function(prompt: str) -> str:
            return f"Unique: {time.time()}"

        # Multiple calls
        results = [llm_function("test") for _ in range(5)]

        # All should return same result (cached)
        assert len(set(results)) == 1

    def test_different_return_types(self):
        """Test caching different return types."""
        backend = MemoryBackend()

        @cache(backend=backend)
        def return_dict(prompt: str) -> dict:
            return {"key": "value"}

        @cache(backend=backend)
        def return_list(prompt: str) -> list:
            return [1, 2, 3]

        @cache(backend=backend)
        def return_string(prompt: str) -> str:
            return "string response"

        # Use unique prompts to avoid cache collision
        assert isinstance(return_dict("test_dict"), dict)
        assert isinstance(return_list("test_list"), list)
assert isinstance(return_string("test_string"), str) + + def test_empty_and_none_responses(self): + """Test caching empty and None responses.""" + backend = MemoryBackend() + + @cache(backend=backend) + def return_empty(prompt: str) -> str: + return "" + + @cache(backend=backend) + def return_none(prompt: str) -> None: + return None + + assert return_empty("empty_test") == "" + assert return_none("none_test") is None + + # Should still cache (second calls should hit cache) + assert return_empty("empty_test") == "" + assert return_none("none_test") is None + + +class TestErrorHandling: + """Tests for error handling in various scenarios.""" + + def test_function_with_exception(self): + """Test function that raises exception.""" + from semantic_llm_cache.exceptions import PromptCacheError + + backend = MemoryBackend() + + @cache(backend=backend) + def failing_function(prompt: str) -> str: + if "error" in prompt: + raise ValueError("Test error") + return "OK" + + # Normal call works + assert failing_function("normal") == "OK" + + # Error call raises PromptCacheError (wrapped exception) + with pytest.raises(PromptCacheError): + failing_function("error prompt") + + # Normal call still works + assert failing_function("normal") == "OK" + + def test_backend_error_handling(self): + """Test that backend wraps errors properly.""" + from semantic_llm_cache.backends.memory import MemoryBackend + + # Use MemoryBackend which has proper error handling + backend = MemoryBackend() + + @cache(backend=backend) + def working_func(prompt: str) -> str: + return f"Response to: {prompt}" + + # Normal operation works + assert working_func("test") == "Response to: test" + + # Second call hits cache + assert working_func("test") == "Response to: test" + + # Backend properly stores and retrieves entries + stats = backend.get_stats() + assert stats["hits"] >= 1 + + +class TestPromptNormalization: + """Tests for prompt normalization effects.""" + + def test_whitespace_normalization(self): + 
"""Test prompts with different whitespace are cached separately.""" + backend = MemoryBackend() + call_count = {"count": 0} + + @cache(backend=backend) + def llm_function(prompt: str) -> str: + call_count["count"] += 1 + return f"Response: {prompt}" + + llm_function("What is Python?") + llm_function("What is Python?") # Extra spaces + + # Normalization should make these the same + # Note: This depends on the normalization implementation + assert call_count["count"] >= 1 + + def test_case_sensitivity(self): + """Test case sensitivity in caching.""" + backend = MemoryBackend() + call_count = {"count": 0} + + @cache(backend=backend) + def llm_function(prompt: str) -> str: + call_count["count"] += 1 + return f"Response: {prompt}" + + llm_function("What is Python?") + llm_function("what is python?") + + # Case differences create different cache entries + # (normalization doesn't lowercase by default) + assert call_count["count"] >= 1 + + +class TestConfigurationCombinations: + """Tests for various configuration combinations.""" + + def test_no_caching_config(self): + """Test configuration with caching disabled.""" + backend = MemoryBackend() + + @cache(backend=backend, enabled=False) + def llm_function(prompt: str) -> str: + return f"Response: {time.time()}" + + result1 = llm_function("test") + time.sleep(0.01) + result2 = llm_function("test") + + # Without caching, results differ + assert result1 != result2 + + def test_zero_ttl(self): + """Test zero TTL means immediate expiration.""" + backend = MemoryBackend() + + @cache(backend=backend, ttl=0) + def llm_function(prompt: str) -> str: + return f"Response: {prompt}" + + llm_function("test") + # Entry immediately expires, so next call is a miss + llm_function("test") + + def test_infinite_ttl(self): + """Test None TTL means never expire.""" + backend = MemoryBackend() + call_count = {"count": 0} + + @cache(backend=backend, ttl=None) + def llm_function(prompt: str) -> str: + call_count["count"] += 1 + return f"Response: 
{prompt}" + + llm_function("test") + llm_function("test") + + assert call_count["count"] == 1 diff --git a/tests/test_similarity.py b/tests/test_similarity.py new file mode 100644 index 0000000..a6c697c --- /dev/null +++ b/tests/test_similarity.py @@ -0,0 +1,208 @@ +"""Tests for embedding generation and similarity matching.""" + +import numpy as np +import pytest + +from semantic_llm_cache.similarity import ( + DummyEmbeddingProvider, + EmbeddingCache, + OpenAIEmbeddingProvider, + SentenceTransformerProvider, + cosine_similarity, + create_embedding_provider, +) + + +class TestCosineSimilarity: + """Tests for cosine_similarity function.""" + + def test_identical_vectors(self): + """Test identical vectors have similarity 1.0.""" + a = [1.0, 2.0, 3.0] + b = [1.0, 2.0, 3.0] + assert cosine_similarity(a, b) == pytest.approx(1.0) + + def test_orthogonal_vectors(self): + """Test orthogonal vectors have similarity 0.0.""" + a = [1.0, 0.0, 0.0] + b = [0.0, 1.0, 0.0] + assert cosine_similarity(a, b) == pytest.approx(0.0) + + def test_opposite_vectors(self): + """Test opposite vectors have similarity -1.0.""" + a = [1.0, 2.0, 3.0] + b = [-1.0, -2.0, -3.0] + assert cosine_similarity(a, b) == pytest.approx(-1.0) + + def test_zero_vectors(self): + """Test zero vectors return 0.0.""" + a = [0.0, 0.0, 0.0] + b = [1.0, 2.0, 3.0] + assert cosine_similarity(a, b) == 0.0 + + def test_numpy_array_input(self): + """Test function accepts numpy arrays.""" + a = np.array([1.0, 2.0, 3.0]) + b = np.array([1.0, 2.0, 3.0]) + assert cosine_similarity(a, b) == pytest.approx(1.0) + + def test_mixed_dimensions(self): + """Test vectors of different dimensions raise error.""" + a = [1.0, 2.0] + b = [1.0, 2.0, 3.0] + # Should raise ValueError for mismatched dimensions + with pytest.raises(ValueError, match="dimension mismatch"): + cosine_similarity(a, b) + + +class TestDummyEmbeddingProvider: + """Tests for DummyEmbeddingProvider.""" + + def test_encode_returns_list(self): + """Test encode returns 
list of floats.""" + provider = DummyEmbeddingProvider() + embedding = provider.encode("test prompt") + assert isinstance(embedding, list) + assert all(isinstance(x, float) for x in embedding) + + def test_encode_deterministic(self): + """Test same input produces same output.""" + provider = DummyEmbeddingProvider() + text = "test prompt" + e1 = provider.encode(text) + e2 = provider.encode(text) + assert e1 == e2 + + def test_encode_different_inputs(self): + """Test different inputs produce different outputs.""" + provider = DummyEmbeddingProvider() + e1 = provider.encode("prompt 1") + e2 = provider.encode("prompt 2") + assert e1 != e2 + + def test_custom_dimension(self): + """Test custom embedding dimension.""" + provider = DummyEmbeddingProvider(dim=128) + embedding = provider.encode("test") + assert len(embedding) == 128 + + def test_embedding_normalized(self): + """Test embeddings are normalized to unit length.""" + provider = DummyEmbeddingProvider() + embedding = provider.encode("test prompt") + # Calculate norm + norm = np.linalg.norm(embedding) + assert norm == pytest.approx(1.0, rel=1e-5) + + +class TestSentenceTransformerProvider: + """Tests for SentenceTransformerProvider.""" + + @pytest.mark.skip(reason="Requires sentence-transformers installation") + def test_encode_returns_list(self): + """Test encode returns list of floats.""" + provider = SentenceTransformerProvider() + embedding = provider.encode("test prompt") + assert isinstance(embedding, list) + assert all(isinstance(x, float) for x in embedding) + + @pytest.mark.skip(reason="Requires sentence-transformers installation") + def test_encode_deterministic(self): + """Test same input produces same output.""" + provider = SentenceTransformerProvider() + text = "test prompt" + e1 = provider.encode(text) + e2 = provider.encode(text) + assert e1 == e2 + + def test_import_error_without_package(self, monkeypatch): + """Test ImportError raised when package not installed.""" + # Skip if 
sentence-transformers is installed + pytest.importorskip("sentence_transformers", reason="sentence-transformers is installed, cannot test import error") + + +class TestOpenAIEmbeddingProvider: + """Tests for OpenAIEmbeddingProvider.""" + + @pytest.mark.skip(reason="Requires OpenAI API key") + def test_encode_returns_list(self): + """Test encode returns list of floats.""" + provider = OpenAIEmbeddingProvider() + embedding = provider.encode("test prompt") + assert isinstance(embedding, list) + assert all(isinstance(x, float) for x in embedding) + + @pytest.mark.skip(reason="Requires OpenAI API key") + def test_encode_deterministic(self): + """Test same input produces same output.""" + provider = OpenAIEmbeddingProvider() + text = "test prompt" + e1 = provider.encode(text) + e2 = provider.encode(text) + # OpenAI embeddings may vary slightly + assert len(e1) == len(e2) + + def test_import_error_without_package(self, monkeypatch): + """Test ImportError raised when package not installed.""" + # Skip if openai is installed + pytest.importorskip("openai", reason="openai is installed, cannot test import error") + + +class TestEmbeddingCache: + """Tests for EmbeddingCache.""" + + def test_cache_provider(self): + """Test cache uses provider.""" + provider = DummyEmbeddingProvider() + cache = EmbeddingCache(provider=provider) + + # Encode same text twice + e1 = cache.encode("test prompt") + e2 = cache.encode("test prompt") + + assert e1 == e2 + + def test_cache_clear(self): + """Test cache can be cleared.""" + provider = DummyEmbeddingProvider() + cache = EmbeddingCache(provider=provider) + + cache.encode("test prompt") + cache.clear_cache() + + # Should still work after clear + embedding = cache.encode("test prompt") + assert len(embedding) > 0 + + def test_cache_default_provider(self): + """Test cache uses dummy provider by default.""" + cache = EmbeddingCache() + embedding = cache.encode("test prompt") + assert isinstance(embedding, list) + assert len(embedding) > 0 + + 
class TestCreateEmbeddingProvider:
    """Tests for create_embedding_provider factory."""

    def test_create_dummy_provider(self):
        """Test creating dummy provider."""
        provider = create_embedding_provider("dummy")
        assert isinstance(provider, DummyEmbeddingProvider)

    def test_create_auto_provider(self):
        """Test auto provider creates sentence-transformers when available."""
        provider = create_embedding_provider("auto")
        # Creates SentenceTransformerProvider if available, else DummyEmbeddingProvider
        assert isinstance(provider, (DummyEmbeddingProvider, SentenceTransformerProvider))

    def test_invalid_provider_type(self):
        """Test invalid provider type raises error."""
        with pytest.raises(ValueError, match="Unknown provider type"):
            create_embedding_provider("invalid_type")

    def test_custom_model_name(self, monkeypatch):
        """Test custom model name is passed through."""
        # NOTE(review): the monkeypatch fixture is unused here.
        # This would work with actual sentence-transformers
        provider = create_embedding_provider("dummy", model_name="custom-model")
        assert isinstance(provider, DummyEmbeddingProvider)

# ============================================================
# tests/test_stats.py (new file in this diff)
# ============================================================
"""Tests for statistics and analytics module."""

import time

import pytest

from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.stats import (
    CacheStats,
    _stats_manager,
    clear_cache,
    export_cache,
    get_stats,
    invalidate,
    warm_cache,
)


@pytest.fixture(autouse=True)
def clear_stats_state():
    """Clear stats state before each test."""
    _stats_manager.clear_stats()
    yield


class TestCacheStats:
    """Tests for CacheStats dataclass."""

    def test_default_values(self):
        """Test CacheStats initializes with defaults."""
        stats = CacheStats()
        assert stats.hits == 0
        assert stats.misses == 0
        assert stats.total_saved_ms == 0.0
        assert stats.estimated_savings_usd == 0.0

    def test_hit_rate_empty(self):
        """Test hit rate with no requests."""
        stats = CacheStats()
        assert stats.hit_rate == 0.0

    def test_hit_rate_all_hits(self):
        """Test hit rate with all cache hits."""
        stats = CacheStats(hits=10, misses=0)
        assert stats.hit_rate == 1.0

    def test_hit_rate_all_misses(self):
        """Test hit rate with all cache misses."""
        stats = CacheStats(hits=0, misses=10)
        assert stats.hit_rate == 0.0

    def test_hit_rate_mixed(self):
        """Test hit rate with mixed hits and misses."""
        stats = CacheStats(hits=7, misses=3)
        assert stats.hit_rate == 0.7

    def test_total_requests(self):
        """Test total requests calculation."""
        stats = CacheStats(hits=5, misses=3)
        assert stats.total_requests == 8

    def test_to_dict(self):
        """Test converting stats to dictionary."""
        stats = CacheStats(hits=10, misses=5, total_saved_ms=1000.0, estimated_savings_usd=0.5)
        result = stats.to_dict()

        assert result["hits"] == 10
        assert result["misses"] == 5
        assert result["hit_rate"] == 2/3  # 10 hits of 15 requests
        assert result["total_requests"] == 15
        assert result["total_saved_ms"] == 1000.0
        assert result["estimated_savings_usd"] == 0.5

    def test_iadd(self):
        """Test in-place addition of stats."""
        stats1 = CacheStats(hits=5, misses=3, total_saved_ms=500.0, estimated_savings_usd=0.25)
        stats2 = CacheStats(hits=3, misses=2, total_saved_ms=300.0, estimated_savings_usd=0.15)

        stats1 += stats2

        assert stats1.hits == 8
        assert stats1.misses == 5
        assert stats1.total_saved_ms == 800.0
        assert stats1.estimated_savings_usd == 0.4


class TestStatsManager:
    """Tests for _StatsManager."""

    def test_record_hit(self):
        """Test recording a cache hit."""
        _stats_manager.record_hit("test_ns", latency_saved_ms=100.0, saved_cost=0.01)

        stats = _stats_manager.get_stats("test_ns")
        assert stats.hits == 1
        assert stats.total_saved_ms == 100.0
        assert stats.estimated_savings_usd == 0.01

    def test_record_multiple_hits(self):
        """Test recording multiple hits."""
        _stats_manager.record_hit("test_ns", latency_saved_ms=50.0, saved_cost=0.005)
        _stats_manager.record_hit("test_ns", latency_saved_ms=75.0, saved_cost=0.008)

        stats = _stats_manager.get_stats("test_ns")
        assert stats.hits == 2
        assert stats.total_saved_ms == 125.0

    def test_record_miss(self):
        """Test recording a cache miss."""
        _stats_manager.record_miss("test_ns")

        stats = _stats_manager.get_stats("test_ns")
        assert stats.misses == 1

    def test_get_stats_namespace(self):
        """Test getting stats for specific namespace."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_miss("ns1")
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)

        stats1 = _stats_manager.get_stats("ns1")
        stats2 = _stats_manager.get_stats("ns2")

        assert stats1.hits == 1
        assert stats1.misses == 1
        assert stats2.hits == 1
        assert stats2.misses == 0

    def test_get_stats_all_namespaces(self):
        """Test getting aggregated stats for all namespaces."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)
        _stats_manager.record_miss("ns1")

        stats = _stats_manager.get_stats(None)  # All namespaces
        assert stats.hits == 2
        assert stats.misses == 1

    def test_get_stats_nonexistent_namespace(self):
        """Test getting stats for namespace with no activity."""
        stats = _stats_manager.get_stats("nonexistent")
        assert stats.hits == 0
        assert stats.misses == 0

    def test_clear_stats_namespace(self):
        """Test clearing stats for specific namespace."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)

        _stats_manager.clear_stats("ns1")

        stats1 = _stats_manager.get_stats("ns1")
        stats2 = _stats_manager.get_stats("ns2")

        assert stats1.hits == 0
        assert stats2.hits == 1

    def test_clear_stats_all(self):
        """Test clearing all stats."""
        _stats_manager.record_hit("ns1", latency_saved_ms=100.0)
        _stats_manager.record_hit("ns2", latency_saved_ms=50.0)

        _stats_manager.clear_stats()

        stats = _stats_manager.get_stats(None)
        assert stats.hits == 0
        assert stats.misses == 0

    def test_set_backend(self):
        """Test setting default backend."""
        custom_backend = MemoryBackend(max_size=10)
        _stats_manager.set_backend(custom_backend)

        retrieved = _stats_manager.get_backend()
        assert retrieved is custom_backend


class TestPublicStatsAPI:
    """Tests for public stats API functions."""

    def test_get_stats(self):
        """Test get_stats returns dictionary."""
        _stats_manager.record_hit("test_ns", latency_saved_ms=100.0)

        stats = get_stats("test_ns")
        assert isinstance(stats, dict)
        assert "hits" in stats
        assert "misses" in stats
        assert "hit_rate" in stats

    def test_clear_cache_all(self):
        """Test clear_cache clears all entries."""
        backend = _stats_manager.get_backend()

        # Add some entries
        entry = CacheEntry(prompt="test", response="response", created_at=time.time())
        backend.set("key1", entry)
        backend.set("key2", entry)

        count = clear_cache()
        assert count >= 0

    def test_clear_cache_namespace(self):
        """Test clear_cache clears specific namespace."""
        backend = _stats_manager.get_backend()

        # Add entries in different namespaces
        entry1 = CacheEntry(prompt="test1", response="r1", namespace="ns1", created_at=time.time())
        entry2 = CacheEntry(prompt="test2", response="r2", namespace="ns2", created_at=time.time())

        backend.set("key1", entry1)
        backend.set("key2", entry2)

        count = clear_cache(namespace="ns1")
        assert count >= 0

    def test_invalidate_pattern(self):
        """Test invalidating entries by pattern."""
        backend = _stats_manager.get_backend()

        # Add entries with different prompts
        entry1 = CacheEntry(prompt="Python programming", response="r1", created_at=time.time())
        entry2 = CacheEntry(prompt="Rust programming", response="r2", created_at=time.time())
        entry3 = CacheEntry(prompt="JavaScript", response="r3", created_at=time.time())

        backend.set("key1", entry1)
        backend.set("key2", entry2)
        backend.set("key3", entry3)

        count = invalidate("Python")
        assert count >= 0

    def test_invalidate_case_insensitive(self):
        """Test invalidate is case insensitive."""
        backend = _stats_manager.get_backend()

        entry = CacheEntry(prompt="PYTHON programming", response="r1", created_at=time.time())
        backend.set("key1", entry)

        count = invalidate("python")
        assert count >= 0

    def test_invalidate_no_matches(self):
        """Test invalidate with no matches."""
        backend = _stats_manager.get_backend()

        entry = CacheEntry(prompt="Rust programming", response="r1", created_at=time.time())
        backend.set("key1", entry)

        count = invalidate("Python")
        assert count == 0

    def test_warm_cache(self):
        """Test warming cache with prompts."""
        prompts = ["prompt1", "prompt2"]

        def mock_llm(prompt: str) -> str:
            return f"Response to: {prompt}"

        count = warm_cache(prompts, mock_llm, namespace="warm_test")
        assert count == len(prompts)

    def test_warm_cache_with_failures(self):
        """Test warm_cache handles LLM failures gracefully."""

        def failing_llm(prompt: str) -> str:
            if "fail" in prompt:
                raise ValueError("LLM error")
            return f"Response to: {prompt}"

        prompts = ["prompt1", "fail_prompt", "prompt3"]
        count = warm_cache(prompts, failing_llm, namespace="warm_fail_test")
        # Should return count even if some prompts fail
        assert count == len(prompts)


class TestExportCache:
    """Tests for export_cache function."""

    def test_export_all_entries(self, tmp_path):
        """Test exporting all cache entries."""
        # NOTE(review): this test continues past the visible end of this
        # chunk; the remainder of its body is outside this view.
        backend = _stats_manager.get_backend()
        backend.clear()

        # Add test entries
        entry1 = CacheEntry(
            prompt="test prompt 1",
            response="response 1",
            namespace="test_ns",
            created_at=time.time(),
            hit_count=5,
            ttl=3600,
            input_tokens=100,
            output_tokens=50,
        )
+ backend.set("key1", entry1) + + entries = export_cache() + assert len(entries) >= 0 + + if entries: + assert "key" in entries[0] + assert "prompt" in entries[0] + assert "response" in entries[0] + assert "namespace" in entries[0] + assert "hit_count" in entries[0] + + def test_export_namespace_filtered(self, tmp_path): + """Test exporting entries filtered by namespace.""" + backend = _stats_manager.get_backend() + backend.clear() + + # Add entries in different namespaces + entry1 = CacheEntry( + prompt="test1", response="r1", namespace="ns1", created_at=time.time() + ) + entry2 = CacheEntry( + prompt="test2", response="r2", namespace="ns2", created_at=time.time() + ) + + backend.set("key1", entry1) + backend.set("key2", entry2) + + entries = export_cache(namespace="ns1") + # Should only return entries from ns1 + assert all(e["namespace"] == "ns1" for e in entries) + + def test_export_to_file(self, tmp_path): + """Test exporting cache to JSON file.""" + import json + + filepath = tmp_path / "export.json" + + backend = _stats_manager.get_backend() + backend.clear() + + entry = CacheEntry( + prompt="test", + response="response", + namespace="test", + created_at=time.time(), + hit_count=3, + ) + backend.set("key1", entry) + + export_cache(filepath=str(filepath)) + + # Verify file was created and is valid JSON + assert filepath.exists() + with open(filepath) as f: + data = json.load(f) + assert isinstance(data, list) + + def test_export_truncates_large_responses(self, tmp_path): + """Test that large responses are truncated in export.""" + backend = _stats_manager.get_backend() + backend.clear() + + # Create entry with very large response + large_response = "x" * 2000 + entry = CacheEntry( + prompt="test", + response=large_response, + created_at=time.time(), + ) + backend.set("key1", entry) + + entries = export_cache() + if entries: + # Response should be truncated to 1000 chars + assert len(entries[0]["response"]) <= 1000 + + +class TestStatsIntegration: + 
"""Integration tests for stats with actual cache operations.""" + + def test_stats_tracking_with_cache_decorator(self): + """Test that stats are tracked during cache operations.""" + from semantic_llm_cache import cache + + backend = MemoryBackend() + _stats_manager.clear_stats("integration_test") + + @cache(backend=backend, namespace="integration_test") + def cached_func(prompt: str) -> str: + return f"Response to: {prompt}" + + # Generate activity + cached_func("prompt1") + cached_func("prompt1") # Hit + cached_func("prompt2") + + stats = get_stats("integration_test") + assert stats["total_requests"] >= 2 + + def test_cache_invalidate_integration(self): + """Test invalidate removes entries from backend.""" + backend = _stats_manager.get_backend() + backend.clear() + + entry = CacheEntry( + prompt="Python is great", + response="Yes, it is!", + created_at=time.time(), + ) + backend.set("key1", entry) + + # Verify entry exists + assert backend.get("key1") is not None + + # Invalidate + invalidate("Python") + + # Entry should be gone + assert backend.get("key1") is None + + def test_export_includes_metadata(self, tmp_path): + """Test export includes all metadata fields.""" + backend = _stats_manager.get_backend() + backend.clear() + + entry = CacheEntry( + prompt="test prompt", + response="test response", + namespace="export_test", + created_at=time.time(), + ttl=7200, + hit_count=10, + input_tokens=500, + output_tokens=250, + embedding=[0.1, 0.2, 0.3], + ) + backend.set("key1", entry) + + entries = export_cache(namespace="export_test") + if entries: + e = entries[0] + assert "created_at" in e + assert "ttl" in e + assert "hit_count" in e + assert "input_tokens" in e + assert "output_tokens" in e