Add files via upload

initial commit
Alpha Nerd 2026-03-06 15:54:47 +01:00 committed by GitHub
parent 8d3d5ff628
commit b33bb415dd
24 changed files with 4840 additions and 0 deletions

semantic_llm_cache/__init__.py Normal file

@@ -0,0 +1,54 @@
"""
llm-semantic-cache: Semantic caching for LLM API calls.
Cut LLM costs by up to 30% with one decorator.
"""
__version__ = "0.1.0"
__author__ = "Karthick Raja M"
__license__ = "MIT"
# Core exports
from semantic_llm_cache.config import CacheConfig
from semantic_llm_cache.core import (
CacheContext,
CachedLLM,
cache,
get_default_backend,
set_default_backend,
)
from semantic_llm_cache.exceptions import (
CacheBackendError,
CacheNotFoundError,
CacheSerializationError,
PromptCacheError,
)
from semantic_llm_cache.stats import CacheStats, clear_cache, get_stats, invalidate
from semantic_llm_cache.storage import StorageBackend
__all__ = [
# Version info
"__version__",
"__author__",
"__license__",
# Core API
"cache",
"CacheContext",
"CachedLLM",
"get_default_backend",
"set_default_backend",
# Storage
"StorageBackend",
# Statistics
"CacheStats",
"get_stats",
"clear_cache",
"invalidate",
# Configuration
"CacheConfig",
# Exceptions
"PromptCacheError",
"CacheBackendError",
"CacheSerializationError",
"CacheNotFoundError",
]
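
A minimal quickstart sketch (not part of the commit) showing the advertised one-decorator workflow; fake_llm is a hypothetical stand-in for a real provider call:

import asyncio
from semantic_llm_cache import cache, get_stats

@cache(similarity=0.9, ttl=3600)
async def fake_llm(prompt: str) -> str:
    # Stand-in for a real LLM API call; only runs on cache misses.
    return f"response to: {prompt}"

async def main() -> None:
    await fake_llm("What is Python?")
    await fake_llm("What is Python?")  # exact repeat: served from cache
    print(get_stats())                 # {'hits': 1, 'misses': 1, ...}

asyncio.run(main())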

semantic_llm_cache/backends/__init__.py Normal file

@@ -0,0 +1,21 @@
"""Storage backends for llm-semantic-cache."""
from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.backends.memory import MemoryBackend
try:
from semantic_llm_cache.backends.sqlite import SQLiteBackend
except ImportError:
SQLiteBackend = None # type: ignore
try:
from semantic_llm_cache.backends.redis import RedisBackend
except ImportError:
RedisBackend = None # type: ignore
__all__ = [
"BaseBackend",
"MemoryBackend",
"SQLiteBackend",
"RedisBackend",
]
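
A short sketch (not part of the commit) of how the optional backends are meant to degrade: the guarded imports above leave SQLiteBackend or RedisBackend as None when their extras are missing, so callers can feature-test before instantiating:

from semantic_llm_cache.backends import MemoryBackend, SQLiteBackend

# Fall back to the dependency-free in-memory backend when aiosqlite is absent.
backend = SQLiteBackend("cache.db") if SQLiteBackend is not None else MemoryBackend()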

semantic_llm_cache/backends/base.py Normal file

@@ -0,0 +1,104 @@
"""Base backend implementation with common functionality."""
import time
from typing import Any, Optional
import numpy as np
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.storage import StorageBackend
def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float:
"""Calculate cosine similarity between two vectors.
Args:
a: First vector
b: Second vector
Returns:
Cosine similarity score in [-1, 1] (typically 0-1 for text embeddings)
"""
a_arr = np.asarray(a, dtype=np.float32)
b_arr = np.asarray(b, dtype=np.float32)
dot_product = np.dot(a_arr, b_arr)
norm_a = np.linalg.norm(a_arr)
norm_b = np.linalg.norm(b_arr)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(dot_product / (norm_a * norm_b))
class BaseBackend(StorageBackend):
"""Base backend with common sync helpers; async public interface via StorageBackend."""
def __init__(self) -> None:
"""Initialize base backend."""
self._hits: int = 0
self._misses: int = 0
def _increment_hits(self) -> None:
"""Increment hit counter."""
self._hits += 1
def _increment_misses(self) -> None:
"""Increment miss counter."""
self._misses += 1
def _check_expired(self, entry: CacheEntry) -> bool:
"""Check if entry is expired.
Args:
entry: CacheEntry to check
Returns:
True if expired, False otherwise
"""
return entry.is_expired(time.time())
def _find_best_match(
self,
candidates: list[tuple[str, CacheEntry]],
query_embedding: list[float],
threshold: float,
) -> Optional[tuple[str, CacheEntry, float]]:
"""Find best matching entry from candidates.
Sync helper: CPU-only numpy ops, safe to call from async context.
Args:
candidates: List of (key, entry) tuples
query_embedding: Query embedding vector
threshold: Minimum similarity threshold
Returns:
(key, entry, similarity) tuple if found above threshold, None otherwise
"""
best_match: Optional[tuple[str, CacheEntry, float]] = None
best_similarity = threshold
for key, entry in candidates:
if entry.embedding is None:
continue
similarity = cosine_similarity(query_embedding, entry.embedding)
if similarity > best_similarity:
best_similarity = similarity
best_match = (key, entry, similarity)
return best_match
async def get_stats(self) -> dict[str, Any]:
"""Get backend statistics.
Returns:
Dictionary with hits and misses
"""
return {
"hits": self._hits,
"misses": self._misses,
"hit_rate": self._hits / max(self._hits + self._misses, 1),
}
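
A quick worked check (not part of the commit) of the cosine_similarity helper on toy vectors:

from semantic_llm_cache.backends.base import cosine_similarity

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0 (identical direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 (orthogonal)
print(cosine_similarity([1.0, 0.0], [0.0, 0.0]))  # 0.0 (zero-vector guard)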

semantic_llm_cache/backends/memory.py Normal file

@@ -0,0 +1,179 @@
"""In-memory storage backend."""
import sys
from typing import Any, Optional
from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.exceptions import CacheBackendError
class MemoryBackend(BaseBackend):
"""In-memory cache storage with LRU eviction.
    All operations are in-memory dict access (no I/O), so async methods
    run directly in the event loop without thread offloading.
"""
def __init__(self, max_size: Optional[int] = None) -> None:
"""Initialize memory backend.
Args:
max_size: Maximum number of entries to store (LRU eviction when reached)
"""
super().__init__()
self._cache: dict[str, CacheEntry] = {}
self._access_order: dict[str, float] = {}
self._max_size = max_size
self._access_counter: float = 0.0
def _evict_if_needed(self) -> None:
"""Evict oldest entry if at capacity."""
if self._max_size is None or len(self._cache) < self._max_size:
return
if self._access_order:
lru_key = min(self._access_order, key=lambda k: self._access_order.get(k, 0))
del self._cache[lru_key]
del self._access_order[lru_key]
def _update_access_time(self, key: str) -> None:
"""Update access time for LRU tracking."""
self._access_counter += 1
self._access_order[key] = self._access_counter
async def get(self, key: str) -> Optional[CacheEntry]:
"""Retrieve cache entry by key.
Args:
key: Cache key to retrieve
Returns:
CacheEntry if found and not expired, None otherwise
"""
try:
entry = self._cache.get(key)
if entry is None:
self._increment_misses()
return None
if self._check_expired(entry):
await self.delete(key)
self._increment_misses()
return None
self._increment_hits()
self._update_access_time(key)
entry.hit_count += 1
return entry
except Exception as e:
raise CacheBackendError(f"Failed to get entry: {e}") from e
async def set(self, key: str, entry: CacheEntry) -> None:
"""Store cache entry.
Args:
key: Cache key to store under
entry: CacheEntry to store
"""
        try:
            # Only evict when inserting a new key; overwriting an existing
            # key does not grow the cache.
            if key not in self._cache:
                self._evict_if_needed()
            self._cache[key] = entry
            self._update_access_time(key)
except Exception as e:
raise CacheBackendError(f"Failed to set entry: {e}") from e
async def delete(self, key: str) -> bool:
"""Delete cache entry.
Args:
key: Cache key to delete
Returns:
True if entry was deleted, False if not found
"""
try:
if key in self._cache:
del self._cache[key]
self._access_order.pop(key, None)
return True
return False
except Exception as e:
raise CacheBackendError(f"Failed to delete entry: {e}") from e
async def clear(self) -> None:
"""Clear all cache entries."""
try:
self._cache.clear()
self._access_order.clear()
except Exception as e:
raise CacheBackendError(f"Failed to clear cache: {e}") from e
async def iterate(
self, namespace: Optional[str] = None
) -> list[tuple[str, CacheEntry]]:
"""Iterate over cache entries, optionally filtered by namespace.
Args:
namespace: Optional namespace filter
Returns:
List of (key, entry) tuples
"""
        try:
            # Apply the expiry filter consistently, with or without a namespace.
            return [
                (k, v)
                for k, v in self._cache.items()
                if (namespace is None or v.namespace == namespace)
                and not self._check_expired(v)
            ]
except Exception as e:
raise CacheBackendError(f"Failed to iterate entries: {e}") from e
async def find_similar(
self,
embedding: list[float],
threshold: float,
namespace: Optional[str] = None,
) -> Optional[tuple[str, CacheEntry, float]]:
"""Find semantically similar cached entry.
Args:
embedding: Query embedding vector
threshold: Minimum similarity score (0-1)
namespace: Optional namespace filter
Returns:
(key, entry, similarity) tuple if found above threshold, None otherwise
"""
try:
candidates = [
(k, v)
for k, v in self._cache.items()
if v.embedding is not None
and not self._check_expired(v)
and (namespace is None or v.namespace == namespace)
]
return self._find_best_match(candidates, embedding, threshold)
except Exception as e:
raise CacheBackendError(f"Failed to find similar entry: {e}") from e
async def get_stats(self) -> dict[str, Any]:
"""Get backend statistics.
Returns:
Dictionary with size, memory usage, hits, misses
"""
        base_stats = await super().get_stats()
        # Shallow approximation: sys.getsizeof does not recurse into entry payloads.
        memory_usage = sys.getsizeof(self._cache) + sum(
            sys.getsizeof(k) + sys.getsizeof(v) for k, v in self._cache.items()
        )
return {
**base_stats,
"size": len(self._cache),
"memory_bytes": memory_usage,
"max_size": self._max_size,
}
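
A small sanity-check sketch (not part of the commit) of the LRU behavior: with max_size=2, re-reading "a" makes "b" the least recently used entry, so inserting "c" evicts "b":

import asyncio
import time
from semantic_llm_cache.backends.memory import MemoryBackend
from semantic_llm_cache.config import CacheEntry

async def demo() -> None:
    backend = MemoryBackend(max_size=2)
    now = time.time()
    await backend.set("a", CacheEntry(prompt="a", response="A", created_at=now))
    await backend.set("b", CacheEntry(prompt="b", response="B", created_at=now))
    await backend.get("a")  # refresh "a"; "b" is now the LRU entry
    await backend.set("c", CacheEntry(prompt="c", response="C", created_at=now))
    print(await backend.get("b"))               # None (evicted)
    print((await backend.get_stats())["size"])  # 2

asyncio.run(demo())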

semantic_llm_cache/backends/redis.py Normal file

@@ -0,0 +1,239 @@
"""Redis distributed storage backend (async via redis.asyncio)."""
import json
from typing import Any, Optional
try:
from redis import asyncio as aioredis
except ImportError as err:
raise ImportError(
"Redis backend requires 'redis' package. "
"Install with: pip install semantic-llm-cache[redis]"
) from err
from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.exceptions import CacheBackendError
class RedisBackend(BaseBackend):
"""Redis-based distributed cache storage (async).
Uses redis.asyncio (bundled with redis>=4.2) for non-blocking I/O.
The connection is created in __init__; no explicit connect() call needed
as redis.asyncio uses a connection pool that connects lazily.
"""
DEFAULT_PREFIX = "semantic_llm_cache:"
def __init__(
self,
url: str = "redis://localhost:6379/0",
prefix: str = DEFAULT_PREFIX,
**kwargs: Any,
) -> None:
"""Initialize Redis backend.
Args:
url: Redis connection URL
prefix: Key prefix for cache entries
**kwargs: Additional arguments passed to redis.asyncio.from_url
"""
super().__init__()
self._prefix = prefix.rstrip(":") + ":"
self._redis = aioredis.from_url(url, **kwargs)
async def ping(self) -> None:
"""Test Redis connection. Call this after construction to verify connectivity.
Raises:
CacheBackendError: If Redis is not reachable
"""
try:
await self._redis.ping()
except Exception as e:
raise CacheBackendError(f"Failed to connect to Redis: {e}") from e
def _make_key(self, key: str) -> str:
"""Create full Redis key with prefix."""
return f"{self._prefix}{key}"
def _entry_to_dict(self, entry: CacheEntry) -> dict[str, Any]:
"""Convert CacheEntry to dictionary for storage."""
return {
"prompt": entry.prompt,
"response": entry.response,
"embedding": entry.embedding,
"created_at": entry.created_at,
"ttl": entry.ttl,
"namespace": entry.namespace,
"hit_count": entry.hit_count,
"input_tokens": entry.input_tokens,
"output_tokens": entry.output_tokens,
}
def _dict_to_entry(self, data: dict[str, Any]) -> CacheEntry:
"""Convert dictionary from storage to CacheEntry."""
return CacheEntry(
prompt=data["prompt"],
response=data["response"],
embedding=data.get("embedding"),
created_at=data["created_at"],
ttl=data.get("ttl"),
namespace=data.get("namespace", "default"),
hit_count=data.get("hit_count", 0),
input_tokens=data.get("input_tokens", 0),
output_tokens=data.get("output_tokens", 0),
)
async def get(self, key: str) -> Optional[CacheEntry]:
"""Retrieve cache entry by key.
Args:
key: Cache key to retrieve
Returns:
CacheEntry if found and not expired, None otherwise
"""
try:
redis_key = self._make_key(key)
data = await self._redis.get(redis_key)
if data is None:
self._increment_misses()
return None
entry_dict = json.loads(data)
entry = self._dict_to_entry(entry_dict)
if self._check_expired(entry):
await self.delete(key)
self._increment_misses()
return None
self._increment_hits()
entry.hit_count += 1
entry_dict["hit_count"] = entry.hit_count
await self._redis.set(redis_key, json.dumps(entry_dict))
return entry
except Exception as e:
raise CacheBackendError(f"Failed to get entry: {e}") from e
async def set(self, key: str, entry: CacheEntry) -> None:
"""Store cache entry.
Args:
key: Cache key to store under
entry: CacheEntry to store
"""
try:
redis_key = self._make_key(key)
data = json.dumps(self._entry_to_dict(entry))
            ex = entry.ttl if entry.ttl is not None and entry.ttl > 0 else None
            await self._redis.set(redis_key, data, ex=ex)
except Exception as e:
raise CacheBackendError(f"Failed to set entry: {e}") from e
async def delete(self, key: str) -> bool:
"""Delete cache entry.
Args:
key: Cache key to delete
Returns:
True if entry was deleted, False if not found
"""
try:
result = await self._redis.delete(self._make_key(key))
return result > 0
except Exception as e:
raise CacheBackendError(f"Failed to delete entry: {e}") from e
async def clear(self) -> None:
"""Clear all cache entries with this prefix."""
try:
keys = await self._redis.keys(f"{self._prefix}*")
if keys:
await self._redis.delete(*keys)
except Exception as e:
raise CacheBackendError(f"Failed to clear cache: {e}") from e
async def iterate(
self, namespace: Optional[str] = None
) -> list[tuple[str, CacheEntry]]:
"""Iterate over cache entries, optionally filtered by namespace.
Args:
namespace: Optional namespace filter
Returns:
List of (key, entry) tuples
"""
try:
keys = await self._redis.keys(f"{self._prefix}*")
results = []
            for full_key in keys:
                # keys() returns bytes unless decode_responses=True was passed.
                raw = full_key.decode() if isinstance(full_key, bytes) else full_key
                short_key = raw.replace(self._prefix, "", 1)
data = await self._redis.get(full_key)
if data:
entry_dict = json.loads(data)
entry = self._dict_to_entry(entry_dict)
if namespace is None or entry.namespace == namespace:
if not self._check_expired(entry):
results.append((short_key, entry))
return results
except Exception as e:
raise CacheBackendError(f"Failed to iterate entries: {e}") from e
async def find_similar(
self,
embedding: list[float],
threshold: float,
namespace: Optional[str] = None,
) -> Optional[tuple[str, CacheEntry, float]]:
"""Find semantically similar cached entry.
Note: Loads all entries for cosine scan. For large datasets consider
Redis Stack with vector search (RediSearch).
Args:
embedding: Query embedding vector
threshold: Minimum similarity score (0-1)
namespace: Optional namespace filter
Returns:
(key, entry, similarity) tuple if found above threshold, None otherwise
"""
try:
entries = await self.iterate(namespace)
candidates = [(k, v) for k, v in entries if v.embedding is not None]
return self._find_best_match(candidates, embedding, threshold)
except Exception as e:
raise CacheBackendError(f"Failed to find similar entry: {e}") from e
async def get_stats(self) -> dict[str, Any]:
"""Get backend statistics."""
base_stats = await super().get_stats()
try:
keys = await self._redis.keys(f"{self._prefix}*")
return {
**base_stats,
"size": len(keys) if keys else 0,
"prefix": self._prefix,
}
except Exception as e:
return {**base_stats, "size": 0, "prefix": self._prefix, "error": str(e)}
async def close(self) -> None:
"""Close Redis connection."""
        try:
            await self._redis.aclose()
        except AttributeError:
            # redis-py < 5.0 spells it close()
            await self._redis.close()
        except Exception:
            pass
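
A usage sketch (not part of the commit), assuming a Redis server is reachable at localhost:6379; ping() surfaces connectivity problems before the first real operation:

import asyncio
import time
from semantic_llm_cache.backends.redis import RedisBackend
from semantic_llm_cache.config import CacheEntry

async def demo() -> None:
    backend = RedisBackend("redis://localhost:6379/0")
    await backend.ping()  # fail fast if Redis is unreachable
    entry = CacheEntry(prompt="hi", response="hello", created_at=time.time(), ttl=60)
    await backend.set("greet", entry)  # stored with a 60s server-side expiry
    hit = await backend.get("greet")
    print(hit.response if hit else "miss")  # hello
    await backend.close()

asyncio.run(demo())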

semantic_llm_cache/backends/sqlite.py Normal file

@@ -0,0 +1,279 @@
"""SQLite persistent storage backend (async via aiosqlite)."""
import json
from pathlib import Path
from typing import Any, Optional
try:
import aiosqlite
except ImportError as err:
raise ImportError(
"SQLite backend requires 'aiosqlite' package. "
"Install with: pip install semantic-llm-cache[sqlite]"
) from err
from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.exceptions import CacheBackendError
class SQLiteBackend(BaseBackend):
"""SQLite-based persistent cache storage (async).
Uses aiosqlite for non-blocking I/O. A single persistent connection
is opened lazily on first use and reused for all subsequent operations.
"""
def __init__(self, db_path: str | Path = "semantic_cache.db") -> None:
"""Initialize SQLite backend.
Args:
db_path: Path to SQLite database file, or ":memory:" for in-memory DB
"""
super().__init__()
self._db_path = str(db_path) if isinstance(db_path, Path) else db_path
self._conn: Optional[aiosqlite.Connection] = None
async def _get_conn(self) -> aiosqlite.Connection:
"""Get or create the persistent async connection."""
if self._conn is None:
self._conn = await aiosqlite.connect(self._db_path)
self._conn.row_factory = aiosqlite.Row
await self._initialize_schema()
return self._conn
async def _initialize_schema(self) -> None:
"""Initialize database schema."""
conn = await self._get_conn()
await conn.execute(
"""
CREATE TABLE IF NOT EXISTS cache_entries (
key TEXT PRIMARY KEY,
prompt TEXT NOT NULL,
response TEXT NOT NULL,
embedding TEXT,
created_at REAL NOT NULL,
ttl INTEGER,
namespace TEXT NOT NULL DEFAULT 'default',
hit_count INTEGER DEFAULT 0,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0
)
"""
)
await conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_namespace
ON cache_entries(namespace)
"""
)
await conn.commit()
def _row_to_entry(self, row: aiosqlite.Row) -> CacheEntry:
"""Convert database row to CacheEntry."""
embedding = None
if row["embedding"]:
embedding = json.loads(row["embedding"])
return CacheEntry(
prompt=row["prompt"],
response=json.loads(row["response"]),
embedding=embedding,
created_at=row["created_at"],
ttl=row["ttl"],
namespace=row["namespace"],
hit_count=row["hit_count"],
input_tokens=row["input_tokens"],
output_tokens=row["output_tokens"],
)
async def get(self, key: str) -> Optional[CacheEntry]:
"""Retrieve cache entry by key.
Args:
key: Cache key to retrieve
Returns:
CacheEntry if found and not expired, None otherwise
"""
try:
conn = await self._get_conn()
async with conn.execute(
"SELECT * FROM cache_entries WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
self._increment_misses()
return None
entry = self._row_to_entry(row)
if self._check_expired(entry):
await self.delete(key)
self._increment_misses()
return None
self._increment_hits()
entry.hit_count += 1
await conn.execute(
"UPDATE cache_entries SET hit_count = hit_count + 1 WHERE key = ?",
(key,),
)
await conn.commit()
return entry
except Exception as e:
raise CacheBackendError(f"Failed to get entry: {e}") from e
async def set(self, key: str, entry: CacheEntry) -> None:
"""Store cache entry.
Args:
key: Cache key to store under
entry: CacheEntry to store
"""
try:
conn = await self._get_conn()
embedding_json = json.dumps(entry.embedding) if entry.embedding else None
response_json = json.dumps(entry.response)
await conn.execute(
"""
INSERT OR REPLACE INTO cache_entries
(key, prompt, response, embedding, created_at, ttl, namespace,
hit_count, input_tokens, output_tokens)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
key,
entry.prompt,
response_json,
embedding_json,
entry.created_at,
entry.ttl,
entry.namespace,
entry.hit_count,
entry.input_tokens,
entry.output_tokens,
),
)
await conn.commit()
except Exception as e:
raise CacheBackendError(f"Failed to set entry: {e}") from e
async def delete(self, key: str) -> bool:
"""Delete cache entry.
Args:
key: Cache key to delete
Returns:
True if entry was deleted, False if not found
"""
try:
conn = await self._get_conn()
async with conn.execute(
"DELETE FROM cache_entries WHERE key = ?", (key,)
) as cursor:
rowcount = cursor.rowcount
await conn.commit()
return rowcount > 0
except Exception as e:
raise CacheBackendError(f"Failed to delete entry: {e}") from e
async def clear(self) -> None:
"""Clear all cache entries."""
try:
conn = await self._get_conn()
await conn.execute("DELETE FROM cache_entries")
await conn.commit()
except Exception as e:
raise CacheBackendError(f"Failed to clear cache: {e}") from e
async def iterate(
self, namespace: Optional[str] = None
) -> list[tuple[str, CacheEntry]]:
"""Iterate over cache entries, optionally filtered by namespace.
Args:
namespace: Optional namespace filter
Returns:
List of (key, entry) tuples
"""
try:
conn = await self._get_conn()
            if namespace is None:
                query = "SELECT * FROM cache_entries"
                params: tuple[Any, ...] = ()
            else:
                query = "SELECT * FROM cache_entries WHERE namespace = ?"
                params = (namespace,)
async with conn.execute(query, params) as cursor:
rows = await cursor.fetchall()
results = []
for row in rows:
key = row["key"]
entry = self._row_to_entry(row)
if not self._check_expired(entry):
results.append((key, entry))
return results
except Exception as e:
raise CacheBackendError(f"Failed to iterate entries: {e}") from e
async def find_similar(
self,
embedding: list[float],
threshold: float,
namespace: Optional[str] = None,
) -> Optional[tuple[str, CacheEntry, float]]:
"""Find semantically similar cached entry.
Args:
embedding: Query embedding vector
threshold: Minimum similarity score (0-1)
namespace: Optional namespace filter
Returns:
(key, entry, similarity) tuple if found above threshold, None otherwise
"""
try:
entries = await self.iterate(namespace)
candidates = [(k, v) for k, v in entries if v.embedding is not None]
return self._find_best_match(candidates, embedding, threshold)
except Exception as e:
raise CacheBackendError(f"Failed to find similar entry: {e}") from e
async def get_stats(self) -> dict[str, Any]:
"""Get backend statistics.
Returns:
Dictionary with size, database path, hits, misses
"""
base_stats = await super().get_stats()
try:
conn = await self._get_conn()
async with conn.execute("SELECT COUNT(*) FROM cache_entries") as cursor:
row = await cursor.fetchone()
size = row[0] if row else 0
return {
**base_stats,
"size": size,
"db_path": self._db_path,
}
except Exception as e:
return {**base_stats, "size": 0, "db_path": self._db_path, "error": str(e)}
async def close(self) -> None:
"""Close database connection."""
if self._conn is not None:
await self._conn.close()
self._conn = None
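
A usage sketch (not part of the commit) exercising the round trip; responses pass through json.dumps/json.loads, so JSON-serializable dicts come back intact:

import asyncio
import time
from semantic_llm_cache.backends.sqlite import SQLiteBackend
from semantic_llm_cache.config import CacheEntry

async def demo() -> None:
    backend = SQLiteBackend(":memory:")  # use a file path for persistence
    entry = CacheEntry(prompt="p", response={"text": "v"}, created_at=time.time())
    await backend.set("k", entry)
    hit = await backend.get("k")
    print(hit.response if hit else "miss")  # {'text': 'v'}
    await backend.close()

asyncio.run(demo())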

semantic_llm_cache/config.py Normal file

@@ -0,0 +1,61 @@
"""Configuration management for prompt-cache."""
from dataclasses import dataclass
from typing import Any, Callable, Optional
@dataclass
class CacheConfig:
"""Configuration for cache behavior."""
similarity_threshold: float = 1.0 # 1.0 = exact match, lower = semantic
ttl: Optional[int] = 3600 # Time to live in seconds, None = forever
namespace: str = "default" # Isolate different use cases
enabled: bool = True # Enable/disable caching
key_func: Optional[Callable[[Any], str]] = None # Custom cache key function
# Cost estimation for statistics (USD per 1K tokens)
input_cost_per_1k: float = 0.001 # Default ~$1/1M for cheaper models
output_cost_per_1k: float = 0.002 # Default ~$2/1M for cheaper models
# Performance settings
max_cache_size: Optional[int] = None # LRU eviction when set
embedding_model: str = "all-MiniLM-L6-v2" # Default sentence-transformer model
def __post_init__(self) -> None:
"""Validate configuration."""
if not 0.0 <= self.similarity_threshold <= 1.0:
raise ValueError("similarity_threshold must be between 0.0 and 1.0")
if self.ttl is not None and self.ttl <= 0:
raise ValueError("ttl must be positive or None")
if self.max_cache_size is not None and self.max_cache_size <= 0:
raise ValueError("max_cache_size must be positive or None")
@dataclass
class CacheEntry:
"""A cached response with metadata."""
prompt: str
response: Any
embedding: Optional[list[float]] = None # Normalized embedding vector
created_at: float = 0.0 # Unix timestamp
ttl: Optional[int] = None # Time to live in seconds
namespace: str = "default"
hit_count: int = 0
# Approximate token counts for cost estimation
input_tokens: int = 0
output_tokens: int = 0
def is_expired(self, current_time: float) -> bool:
"""Check if entry has expired based on TTL."""
if self.ttl is None:
return False
return (current_time - self.created_at) > self.ttl
def estimate_cost(self, input_cost: float, output_cost: float) -> float:
"""Estimate cost savings in USD."""
input_savings = (self.input_tokens / 1000) * input_cost
output_savings = (self.output_tokens / 1000) * output_cost
return input_savings + output_savings
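
A worked example (not part of the commit) of the TTL and cost arithmetic above: 1000 input tokens at $0.001/1K plus 500 output tokens at $0.002/1K is $0.001 + $0.001 = $0.002 saved per hit:

import time
from semantic_llm_cache.config import CacheConfig, CacheEntry

cfg = CacheConfig()
entry = CacheEntry(prompt="p", response="r", created_at=time.time(), ttl=60,
                   input_tokens=1000, output_tokens=500)
print(entry.is_expired(time.time()))        # False (within the 60s TTL)
print(entry.is_expired(time.time() + 120))  # True (past the 60s TTL)
print(entry.estimate_cost(cfg.input_cost_per_1k, cfg.output_cost_per_1k))  # 0.002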

semantic_llm_cache/core.py Normal file

@@ -0,0 +1,369 @@
"""Core cache decorator and API for llm-semantic-cache."""
import functools
import inspect
import time
from typing import Any, Callable, Optional, ParamSpec, TypeVar
from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.config import CacheConfig, CacheEntry
from semantic_llm_cache.exceptions import PromptCacheError
from semantic_llm_cache.similarity import EmbeddingCache
from semantic_llm_cache.stats import _stats_manager
from semantic_llm_cache.utils import hash_prompt, normalize_prompt
P = ParamSpec("P")
R = TypeVar("R")
def _extract_prompt(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
"""Extract prompt string from function arguments."""
if args and isinstance(args[0], str):
return args[0]
if "prompt" in kwargs:
return str(kwargs["prompt"])
return str(args) + str(sorted(kwargs.items()))
class CacheContext:
    """Context manager for cache configuration.
    Supports both sync (with) and async (async with) usage. Currently a
    lightweight holder: it validates and carries a CacheConfig plus a local
    stats dict; it does not alter global cache behavior on its own.
    Examples:
        >>> async with CacheContext(similarity=0.9) as ctx:
        ...     result = await llm_call("prompt")
        ...     print(ctx.stats)
    """
def __init__(
self,
similarity: Optional[float] = None,
ttl: Optional[int] = None,
namespace: Optional[str] = None,
enabled: Optional[bool] = None,
) -> None:
self._config = CacheConfig(
similarity_threshold=similarity if similarity is not None else 1.0,
ttl=ttl,
namespace=namespace if namespace is not None else "default",
enabled=enabled if enabled is not None else True,
)
self._stats: dict[str, Any] = {"hits": 0, "misses": 0}
def __enter__(self) -> "CacheContext":
return self
def __exit__(self, *args: Any) -> None:
pass
async def __aenter__(self) -> "CacheContext":
return self
async def __aexit__(self, *args: Any) -> None:
pass
@property
def stats(self) -> dict[str, Any]:
return self._stats.copy()
@property
def config(self) -> CacheConfig:
return self._config
class CachedLLM:
"""Wrapper class for LLM calls with automatic caching.
Examples:
>>> llm = CachedLLM(similarity=0.9)
>>> response = await llm.achat("What is Python?", llm_func=my_async_llm)
"""
def __init__(
self,
provider: str = "openai",
model: str = "gpt-4",
similarity: float = 1.0,
ttl: Optional[int] = 3600,
backend: Optional[BaseBackend] = None,
namespace: str = "default",
enabled: bool = True,
) -> None:
self._provider = provider
self._model = model
self._backend = backend or MemoryBackend()
self._embedding_cache = EmbeddingCache()
self._config = CacheConfig(
similarity_threshold=similarity,
ttl=ttl,
namespace=namespace,
enabled=enabled,
)
async def achat(
self,
prompt: str,
llm_func: Optional[Callable[[str], Any]] = None,
**kwargs: Any,
) -> Any:
"""Get response with caching (async).
Args:
prompt: Input prompt
llm_func: Async or sync LLM function to call on cache miss
**kwargs: Additional arguments for llm_func
Returns:
LLM response (cached or fresh)
"""
if llm_func is None:
raise ValueError("llm_func is required for CachedLLM.achat()")
@cache(
similarity=self._config.similarity_threshold,
ttl=self._config.ttl,
backend=self._backend,
namespace=self._config.namespace,
enabled=self._config.enabled,
)
async def _cached_call(p: str) -> Any:
result = llm_func(p, **kwargs)
if inspect.isawaitable(result):
return await result
return result
return await _cached_call(prompt)
def cache(
similarity: float = 1.0,
ttl: Optional[int] = 3600,
backend: Optional[BaseBackend] = None,
namespace: str = "default",
enabled: bool = True,
key_func: Optional[Callable[..., str]] = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator for caching LLM function responses.
Auto-detects whether the decorated function is async or sync and returns
the appropriate wrapper. Both variants share identical cache logic.
Async functions get a true async wrapper (awaits all backend calls).
    Sync functions get a sync wrapper that drives the async backends via a
    temporary event loop; this is not suitable inside a running loop, so prefer
    decorating async functions when integrating with async frameworks like FastAPI.
Args:
similarity: Cosine similarity threshold (1.0=exact, 0.9=semantic)
ttl: Time-to-live in seconds (None=forever)
backend: Async storage backend (None=in-memory)
namespace: Cache namespace for isolation
enabled: Whether caching is enabled
key_func: Custom cache key function
Returns:
Decorated function with caching
Examples:
>>> @cache(similarity=0.9, ttl=3600)
... async def ask_llm(prompt: str) -> str:
... return await call_ollama(prompt)
>>> @cache()
... def ask_llm_sync(prompt: str) -> str:
... return call_ollama_sync(prompt)
"""
_backend = backend or MemoryBackend()
embedding_cache = EmbeddingCache()
def decorator(func: Callable[P, R]) -> Callable[P, R]:
if inspect.iscoroutinefunction(func):
# ── Async wrapper ────────────────────────────────────────────────
@functools.wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if not enabled:
return await func(*args, **kwargs)
start_time = time.time()
prompt = _extract_prompt(args, kwargs) # type: ignore[arg-type]
normalized = normalize_prompt(prompt)
cache_key = (
key_func(*args, **kwargs) # type: ignore[arg-type]
if key_func
else hash_prompt(normalized, namespace)
)
# 1. Exact match
entry = await _backend.get(cache_key)
if entry is not None:
latency_ms = (time.time() - start_time) * 1000
_stats_manager.record_hit(
namespace,
latency_saved_ms=latency_ms,
saved_cost=entry.estimate_cost(0.001, 0.002),
)
return entry.response # type: ignore[return-value]
# 2. Semantic match
if similarity < 1.0:
query_embedding = await embedding_cache.aencode(normalized)
result = await _backend.find_similar(
query_embedding, threshold=similarity, namespace=namespace
)
if result is not None:
_, matched_entry, _ = result
latency_ms = (time.time() - start_time) * 1000
_stats_manager.record_hit(
namespace,
latency_saved_ms=latency_ms,
saved_cost=matched_entry.estimate_cost(0.001, 0.002),
)
return matched_entry.response # type: ignore[return-value]
# 3. Cache miss — call through
_stats_manager.record_miss(namespace)
try:
response = await func(*args, **kwargs)
except Exception as e:
raise PromptCacheError(f"LLM function call failed: {e}") from e
embedding = None
if similarity < 1.0:
embedding = await embedding_cache.aencode(normalized)
await _backend.set(
cache_key,
CacheEntry(
prompt=normalized,
response=response,
embedding=embedding,
created_at=time.time(),
ttl=ttl,
namespace=namespace,
hit_count=0,
input_tokens=len(normalized) // 4,
output_tokens=len(str(response)) // 4,
),
)
return response # type: ignore[return-value]
return async_wrapper # type: ignore[return-value]
else:
# ── Sync wrapper (backwards compatibility) ───────────────────────
# Drives async backends via a dedicated event loop per call.
# Do NOT use inside a running event loop (e.g. FastAPI handlers).
import asyncio
@functools.wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if not enabled:
return func(*args, **kwargs)
start_time = time.time()
prompt = _extract_prompt(args, kwargs) # type: ignore[arg-type]
normalized = normalize_prompt(prompt)
cache_key = (
key_func(*args, **kwargs) # type: ignore[arg-type]
if key_func
else hash_prompt(normalized, namespace)
)
loop = asyncio.new_event_loop()
try:
# 1. Exact match
entry = loop.run_until_complete(_backend.get(cache_key))
if entry is not None:
latency_ms = (time.time() - start_time) * 1000
_stats_manager.record_hit(
namespace,
latency_saved_ms=latency_ms,
saved_cost=entry.estimate_cost(0.001, 0.002),
)
return entry.response # type: ignore[return-value]
# 2. Semantic match
if similarity < 1.0:
query_embedding = embedding_cache.encode(normalized)
result = loop.run_until_complete(
_backend.find_similar(
query_embedding, threshold=similarity, namespace=namespace
)
)
if result is not None:
_, matched_entry, _ = result
latency_ms = (time.time() - start_time) * 1000
_stats_manager.record_hit(
namespace,
latency_saved_ms=latency_ms,
saved_cost=matched_entry.estimate_cost(0.001, 0.002),
)
return matched_entry.response # type: ignore[return-value]
# 3. Cache miss
_stats_manager.record_miss(namespace)
try:
response = func(*args, **kwargs)
except Exception as e:
raise PromptCacheError(f"LLM function call failed: {e}") from e
embedding = None
if similarity < 1.0:
embedding = embedding_cache.encode(normalized)
loop.run_until_complete(
_backend.set(
cache_key,
CacheEntry(
prompt=normalized,
response=response,
embedding=embedding,
created_at=time.time(),
ttl=ttl,
namespace=namespace,
hit_count=0,
input_tokens=len(normalized) // 4,
output_tokens=len(str(response)) // 4,
),
)
)
return response # type: ignore[return-value]
finally:
loop.close()
return sync_wrapper # type: ignore[return-value]
return decorator
# Global default backend for utility functions
_default_backend: Optional[BaseBackend] = None
def get_default_backend() -> BaseBackend:
"""Get default storage backend."""
global _default_backend
if _default_backend is None:
_default_backend = MemoryBackend()
return _default_backend
def set_default_backend(backend: BaseBackend) -> None:
"""Set default storage backend."""
global _default_backend
_default_backend = backend
_stats_manager.set_backend(backend)
__all__ = [
"cache",
"CacheContext",
"CachedLLM",
"get_default_backend",
"set_default_backend",
]
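
A CachedLLM usage sketch (not part of the commit); my_llm is a hypothetical stand-in for a provider call. Each achat() builds a fresh decorated function, but the backend is shared across calls, so the second request is served from cache:

import asyncio
from semantic_llm_cache.core import CachedLLM

async def my_llm(prompt: str) -> str:
    return f"answer: {prompt}"  # hypothetical stand-in for a real provider

async def main() -> None:
    llm = CachedLLM(similarity=0.9, ttl=600)
    print(await llm.achat("Explain asyncio", llm_func=my_llm))  # miss: calls my_llm
    print(await llm.achat("Explain asyncio", llm_func=my_llm))  # exact hit: cached

asyncio.run(main())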

semantic_llm_cache/exceptions.py Normal file

@@ -0,0 +1,25 @@
"""Custom exceptions for prompt-cache."""
class PromptCacheError(Exception):
"""Base exception for prompt-cache errors."""
pass
class CacheBackendError(PromptCacheError):
"""Exception raised when backend operations fail."""
pass
class CacheSerializationError(PromptCacheError):
"""Exception raised when serialization/deserialization fails."""
pass
class CacheNotFoundError(PromptCacheError):
"""Exception raised when cache entry is not found."""
pass

semantic_llm_cache/py.typed Normal file

@@ -0,0 +1 @@
# PEP 561 marker file for type hints

semantic_llm_cache/similarity.py Normal file

@@ -0,0 +1,283 @@
"""Embedding generation and similarity matching for llm-semantic-cache."""
import asyncio
import hashlib
from functools import lru_cache
from typing import Optional
import numpy as np
from semantic_llm_cache.exceptions import PromptCacheError
class EmbeddingProvider:
"""Base class for embedding providers."""
def encode(self, text: str) -> list[float]:
"""Generate embedding for text.
Args:
text: Input text to encode
Returns:
Embedding vector as list of floats
"""
raise NotImplementedError
class DummyEmbeddingProvider(EmbeddingProvider):
    """Fallback embedding provider using hash-based vectors.
    Produces deterministic vectors with no external dependencies. The vectors
    are not semantically meaningful, but they are stable, so exact repeats of
    a prompt still map to the same cache key.
    """
def __init__(self, dim: int = 384) -> None:
"""Initialize dummy provider.
Args:
dim: Embedding dimension (matches MiniLM default)
"""
self._dim = dim
def encode(self, text: str) -> list[float]:
"""Generate hash-based embedding for text.
Args:
text: Input text to encode
Returns:
Deterministic embedding vector based on text hash
"""
        hash_obj = hashlib.sha256(text.encode())
        hash_bytes = hash_obj.digest()  # 32 bytes; dims beyond 32 are zero-padded below
        values = np.frombuffer(hash_bytes, dtype=np.uint8)[: self._dim].astype(
            np.float32
        )
if len(values) < self._dim:
values = np.pad(values, (0, self._dim - len(values)))
norm = np.linalg.norm(values)
if norm > 0:
values = values / norm
return values.tolist()
class SentenceTransformerProvider(EmbeddingProvider):
"""Sentence-transformers based embedding provider.
Uses local models like MiniLM for semantic embeddings.
Inference is CPU/GPU-bound; use aencode() from async contexts.
"""
def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
"""Initialize sentence-transformer provider.
Args:
model_name: Name of sentence-transformer model
"""
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
raise PromptCacheError(
"sentence-transformers package required for semantic matching. "
"Install with: pip install semantic-llm-cache[semantic]"
) from e
self._model = SentenceTransformer(model_name)
self._dim = self._model.get_sentence_embedding_dimension()
def encode(self, text: str) -> list[float]:
"""Generate embedding for text (blocking — use aencode from async code).
Args:
text: Input text to encode
Returns:
Normalized embedding vector
"""
embedding = self._model.encode(text, convert_to_numpy=True)
embedding = np.asarray(embedding, dtype=np.float32)
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
return embedding.tolist()
class OpenAIEmbeddingProvider(EmbeddingProvider):
"""OpenAI API-based embedding provider.
Uses OpenAI's embedding API for high-quality semantic embeddings.
    Network I/O: always use aencode() from async contexts.
"""
def __init__(
self, api_key: Optional[str] = None, model: str = "text-embedding-3-small"
) -> None:
"""Initialize OpenAI embedding provider.
Args:
api_key: OpenAI API key (uses OPENAI_API_KEY env var if None)
model: OpenAI embedding model to use
"""
try:
import openai
except ImportError as e:
raise PromptCacheError(
"openai package required for OpenAI embeddings. "
"Install with: pip install semantic-llm-cache[openai]"
) from e
self._client = openai.OpenAI(api_key=api_key)
self._model = model
def encode(self, text: str) -> list[float]:
"""Generate embedding for text (blocking — use aencode from async code).
Args:
text: Input text to encode
Returns:
OpenAI embedding vector (already normalized)
"""
response = self._client.embeddings.create(input=text, model=self._model)
embedding = response.data[0].embedding
embedding_arr = np.asarray(embedding, dtype=np.float32)
norm = np.linalg.norm(embedding_arr)
if norm > 0:
embedding_arr = embedding_arr / norm
return embedding_arr.tolist()
def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float:
"""Calculate cosine similarity between two vectors.
Args:
a: First vector
b: Second vector
Returns:
        Cosine similarity score in [-1, 1] (typically 0-1 for text embeddings)
Raises:
ValueError: If vectors have different dimensions
"""
a_arr = np.asarray(a, dtype=np.float32)
b_arr = np.asarray(b, dtype=np.float32)
if a_arr.shape != b_arr.shape:
raise ValueError(
f"Vector dimension mismatch: {a_arr.shape} != {b_arr.shape}"
)
dot_product = np.dot(a_arr, b_arr)
norm_a = np.linalg.norm(a_arr)
norm_b = np.linalg.norm(b_arr)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(dot_product / (norm_a * norm_b))
def _encode_with_provider(text: str, provider: EmbeddingProvider) -> tuple[float, ...]:
"""Helper function for LRU cache encoding.
Args:
text: Input text
provider: Embedding provider
Returns:
Embedding as tuple for hashability
"""
return tuple(provider.encode(text))
class EmbeddingCache:
"""Cache for embedding generation with LRU eviction.
Use encode() from sync contexts, aencode() from async contexts.
aencode() offloads blocking inference to a thread pool via asyncio.to_thread.
"""
def __init__(
self,
provider: Optional[EmbeddingProvider] = None,
cache_size: int = 1024,
) -> None:
"""Initialize embedding cache.
Args:
provider: Embedding provider (uses DummyEmbeddingProvider if None)
cache_size: Maximum number of embeddings to cache
"""
self._provider = provider or DummyEmbeddingProvider()
self._cache_size = cache_size
self._get_cached = lru_cache(maxsize=cache_size)(_encode_with_provider)
def encode(self, text: str) -> list[float]:
"""Generate embedding with LRU caching (sync, blocking).
Args:
text: Input text to encode
Returns:
Embedding vector
"""
return list(self._get_cached(text, self._provider))
async def aencode(self, text: str) -> list[float]:
"""Generate embedding with LRU caching (async, non-blocking).
CPU/network-bound work is offloaded to the default thread pool via
asyncio.to_thread, keeping the event loop free.
Args:
text: Input text to encode
Returns:
Embedding vector
"""
return await asyncio.to_thread(self.encode, text)
def clear_cache(self) -> None:
"""Clear the embedding LRU cache."""
self._get_cached.cache_clear()
def create_embedding_provider(
provider_type: str = "auto",
model_name: Optional[str] = None,
) -> EmbeddingProvider:
"""Create embedding provider based on type.
Args:
provider_type: Type of provider ("auto", "sentence-transformer", "openai", "dummy")
model_name: Optional model name to use
Returns:
EmbeddingProvider instance
"""
if provider_type == "auto":
try:
return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")
except PromptCacheError:
return DummyEmbeddingProvider()
if provider_type == "sentence-transformer":
return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")
if provider_type == "openai":
return OpenAIEmbeddingProvider(model=model_name)
if provider_type == "dummy":
return DummyEmbeddingProvider()
raise ValueError(f"Unknown provider type: {provider_type}")

semantic_llm_cache/stats.py Normal file

@@ -0,0 +1,255 @@
"""Statistics and analytics for llm-semantic-cache."""
from dataclasses import dataclass
from threading import Lock
from typing import Any, Callable, Optional
from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.backends.base import BaseBackend
@dataclass
class CacheStats:
"""Statistics for cache performance."""
hits: int = 0
misses: int = 0
total_saved_ms: float = 0.0
estimated_savings_usd: float = 0.0
@property
def hit_rate(self) -> float:
"""Calculate cache hit rate."""
total = self.hits + self.misses
return self.hits / max(total, 1)
@property
def total_requests(self) -> int:
"""Get total number of requests."""
return self.hits + self.misses
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"hits": self.hits,
"misses": self.misses,
"hit_rate": self.hit_rate,
"total_requests": self.total_requests,
"total_saved_ms": self.total_saved_ms,
"estimated_savings_usd": self.estimated_savings_usd,
}
def __iadd__(self, other: "CacheStats") -> "CacheStats":
self.hits += other.hits
self.misses += other.misses
self.total_saved_ms += other.total_saved_ms
self.estimated_savings_usd += other.estimated_savings_usd
return self
class _StatsManager:
"""Manager for global cache statistics.
    Uses threading.Lock for record_hit/record_miss; these are simple counter
    increments with no awaits inside the lock, so threading.Lock is safe and
    avoids the overhead of asyncio.Lock for hot-path calls.
"""
def __init__(self) -> None:
"""Initialize stats manager."""
self._stats: dict[str, CacheStats] = {}
self._lock = Lock()
self._default_backend: Optional[BaseBackend] = None
def get_backend(self) -> BaseBackend:
"""Get default backend for cache operations."""
if self._default_backend is None:
self._default_backend = MemoryBackend()
return self._default_backend
def set_backend(self, backend: BaseBackend) -> None:
"""Set default backend for cache operations."""
with self._lock:
self._default_backend = backend
def record_hit(
self,
namespace: str,
latency_saved_ms: float = 0.0,
saved_cost: float = 0.0,
) -> None:
"""Record a cache hit (sync, safe to call from async context)."""
with self._lock:
if namespace not in self._stats:
self._stats[namespace] = CacheStats()
stats = self._stats[namespace]
stats.hits += 1
stats.total_saved_ms += latency_saved_ms
stats.estimated_savings_usd += saved_cost
def record_miss(self, namespace: str) -> None:
"""Record a cache miss (sync, safe to call from async context)."""
with self._lock:
if namespace not in self._stats:
self._stats[namespace] = CacheStats()
self._stats[namespace].misses += 1
def get_stats(self, namespace: Optional[str] = None) -> CacheStats:
"""Get statistics for namespace or all."""
with self._lock:
if namespace is not None:
return self._stats.get(namespace, CacheStats())
total = CacheStats()
for stats in self._stats.values():
total += stats
return total
def clear_stats(self, namespace: Optional[str] = None) -> None:
"""Clear statistics for namespace or all."""
with self._lock:
if namespace is None:
self._stats.clear()
elif namespace in self._stats:
del self._stats[namespace]
# Global stats manager instance
_stats_manager = _StatsManager()
def get_stats(namespace: Optional[str] = None) -> dict[str, Any]:
"""Get cache statistics (sync).
Args:
namespace: Optional namespace to filter by
Returns:
Dictionary with cache statistics
"""
return _stats_manager.get_stats(namespace).to_dict()
async def clear_cache(namespace: Optional[str] = None) -> int:
"""Clear all cached entries (async).
Args:
namespace: Optional namespace to clear (None = all)
Returns:
Number of entries cleared
"""
backend = _stats_manager.get_backend()
if namespace is None:
stats = await backend.get_stats()
size = stats.get("size", 0)
await backend.clear()
_stats_manager.clear_stats()
return size
entries = await backend.iterate(namespace=namespace)
count = len(entries)
for key, _ in entries:
await backend.delete(key)
_stats_manager.clear_stats(namespace)
return count
async def invalidate(
pattern: str,
namespace: Optional[str] = None,
) -> int:
"""Invalidate cache entries matching pattern (async).
Args:
pattern: String pattern to match in prompts
namespace: Optional namespace filter
Returns:
Number of entries invalidated
"""
backend = _stats_manager.get_backend()
entries = await backend.iterate(namespace=namespace)
count = 0
pattern_lower = pattern.lower()
for key, entry in entries:
if pattern_lower in entry.prompt.lower():
await backend.delete(key)
count += 1
return count
async def warm_cache(
prompts: list[str],
llm_func: Callable[[str], Any],
namespace: str = "default",
) -> int:
"""Pre-populate cache with prompts (async).
Args:
prompts: List of prompts to cache
llm_func: Async or sync LLM function to call for each prompt
namespace: Cache namespace to use
Returns:
Number of prompts attempted
"""
    import inspect
    from semantic_llm_cache.core import cache
    # Note: sync llm_funcs are wrapped with a per-call event loop, which cannot
    # run inside this already-running loop; pass an async llm_func here.
    cached_func = cache(namespace=namespace)(llm_func)
for prompt in prompts:
try:
result = cached_func(prompt)
if inspect.isawaitable(result):
await result
except Exception:
pass
return len(prompts)
async def export_cache(
namespace: Optional[str] = None,
filepath: Optional[str] = None,
) -> list[dict[str, Any]]:
"""Export cache entries for analysis (async).
Args:
namespace: Optional namespace filter
filepath: Optional file path to save export (JSON)
Returns:
List of cache entry dictionaries
"""
import json
from datetime import datetime
backend = _stats_manager.get_backend()
entries = await backend.iterate(namespace=namespace)
export_data = []
for key, entry in entries:
export_data.append({
"key": key,
"prompt": entry.prompt,
"response": str(entry.response)[:1000],
"namespace": entry.namespace,
"hit_count": entry.hit_count,
"created_at": datetime.fromtimestamp(entry.created_at).isoformat(),
"ttl": entry.ttl,
"input_tokens": entry.input_tokens,
"output_tokens": entry.output_tokens,
})
if filepath:
with open(filepath, "w", encoding="utf-8") as f:
json.dump(export_data, f, indent=2)
return export_data
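
A sketch (not part of the commit) wiring the module-level utilities to a shared backend; set_default_backend points get_stats/invalidate/clear_cache at the same store the decorator writes to:

import asyncio
from semantic_llm_cache import cache, get_stats, invalidate, set_default_backend
from semantic_llm_cache.backends import MemoryBackend

backend = MemoryBackend()
set_default_backend(backend)  # stats utilities now operate on this backend

@cache(backend=backend)
async def ask(prompt: str) -> str:
    return f"echo: {prompt}"  # hypothetical stand-in for a real LLM call

async def main() -> None:
    await ask("weather in Paris")
    removed = await invalidate("paris")  # case-insensitive prompt substring
    print(removed, get_stats())          # 1 {'hits': 0, 'misses': 1, ...}

asyncio.run(main())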

semantic_llm_cache/storage.py Normal file

@@ -0,0 +1,111 @@
"""Storage backend interface for prompt-cache."""
from abc import ABC, abstractmethod
from typing import Any, Optional
from semantic_llm_cache.config import CacheEntry
class StorageBackend(ABC):
"""Abstract base class for async cache storage backends."""
@abstractmethod
async def get(self, key: str) -> Optional[CacheEntry]:
"""Retrieve cache entry by key.
Args:
key: Cache key to retrieve
Returns:
CacheEntry if found and not expired, None otherwise
Raises:
CacheBackendError: If backend operation fails
"""
pass
@abstractmethod
async def set(self, key: str, entry: CacheEntry) -> None:
"""Store cache entry.
Args:
key: Cache key to store under
entry: CacheEntry to store
Raises:
CacheBackendError: If backend operation fails
"""
pass
@abstractmethod
async def delete(self, key: str) -> bool:
"""Delete cache entry.
Args:
key: Cache key to delete
Returns:
True if entry was deleted, False if not found
Raises:
CacheBackendError: If backend operation fails
"""
pass
@abstractmethod
async def clear(self) -> None:
"""Clear all cache entries.
Raises:
CacheBackendError: If backend operation fails
"""
pass
@abstractmethod
async def iterate(self, namespace: Optional[str] = None) -> list[tuple[str, CacheEntry]]:
"""Iterate over cache entries, optionally filtered by namespace.
Args:
namespace: Optional namespace filter
Returns:
List of (key, entry) tuples
Raises:
CacheBackendError: If backend operation fails
"""
pass
@abstractmethod
async def find_similar(
self,
embedding: list[float],
threshold: float,
namespace: Optional[str] = None,
) -> Optional[tuple[str, CacheEntry, float]]:
"""Find semantically similar cached entry.
Args:
embedding: Query embedding vector
threshold: Minimum similarity score (0-1)
namespace: Optional namespace filter
Returns:
(key, entry, similarity) tuple if found above threshold, None otherwise
Raises:
CacheBackendError: If backend operation fails
"""
pass
@abstractmethod
async def get_stats(self) -> dict[str, Any]:
"""Get backend statistics.
Returns:
Dictionary with stats like size, memory_usage, etc.
Raises:
CacheBackendError: If backend operation fails
"""
pass
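
A minimal custom-backend sketch (not part of the commit): because BaseBackend already supplies hit/miss counters, expiry checks, _find_best_match, and get_stats, a new backend only has to implement the storage primitives. DictBackend below is a hypothetical toy with no eviction:

from typing import Optional
from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.config import CacheEntry

class DictBackend(BaseBackend):
    """Toy backend backed by a plain dict; no eviction, no persistence."""
    def __init__(self) -> None:
        super().__init__()
        self._data: dict[str, CacheEntry] = {}
    async def get(self, key: str) -> Optional[CacheEntry]:
        entry = self._data.get(key)
        if entry is None or self._check_expired(entry):
            self._increment_misses()
            return None
        self._increment_hits()
        return entry
    async def set(self, key: str, entry: CacheEntry) -> None:
        self._data[key] = entry
    async def delete(self, key: str) -> bool:
        return self._data.pop(key, None) is not None
    async def clear(self) -> None:
        self._data.clear()
    async def iterate(self, namespace: Optional[str] = None) -> list[tuple[str, CacheEntry]]:
        return [(k, v) for k, v in self._data.items()
                if namespace is None or v.namespace == namespace]
    async def find_similar(self, embedding: list[float], threshold: float,
                           namespace: Optional[str] = None) -> Optional[tuple[str, CacheEntry, float]]:
        candidates = [(k, v) for k, v in await self.iterate(namespace) if v.embedding is not None]
        return self._find_best_match(candidates, embedding, threshold)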

semantic_llm_cache/utils.py Normal file

@@ -0,0 +1,97 @@
"""Utility functions for prompt-cache."""
import hashlib
import re
from typing import Any
def normalize_prompt(prompt: str) -> str:
"""Normalize prompt text for consistent caching.
Args:
prompt: Raw prompt text
Returns:
Normalized prompt text
"""
# Remove extra whitespace
prompt = " ".join(prompt.split())
# Lowercase for better matching (optional - can affect semantics)
# prompt = prompt.lower()
# Remove common filler words at start
filler_pattern = r"^(please|can you|could you|i need|i want)\s+"
prompt = re.sub(filler_pattern, "", prompt, flags=re.IGNORECASE)
# Normalize quotes
prompt = prompt.replace('"', "'").replace("`", "'")
# Remove trailing punctuation
prompt = prompt.rstrip("?!.")
return prompt.strip()
def hash_prompt(prompt: str, namespace: str = "default") -> str:
"""Generate cache key from prompt and namespace.
Args:
prompt: Prompt text
namespace: Cache namespace
Returns:
Hash-based cache key
"""
combined = f"{namespace}:{prompt}"
return hashlib.sha256(combined.encode()).hexdigest()
def estimate_tokens(text: str) -> int:
"""Estimate token count for text (rough approximation).
Args:
text: Input text
Returns:
Estimated token count
"""
# Rough approximation: ~4 chars per token
return len(text) // 4
def serialize_response(response: Any) -> str:
"""Serialize response for storage.
Args:
response: Response object (string, dict, etc.)
Returns:
Serialized JSON string
"""
import json
return json.dumps(response)
def deserialize_response(data: str) -> Any:
"""Deserialize response from storage.
Args:
data: Serialized JSON string
Returns:
Deserialized response object
"""
import json
return json.loads(data)
__all__ = [
"normalize_prompt",
"hash_prompt",
"estimate_tokens",
"serialize_response",
"deserialize_response",
]
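
A worked example (not part of the commit) of how normalization collapses near-duplicate phrasings onto one cache key:

from semantic_llm_cache.utils import hash_prompt, normalize_prompt

a = normalize_prompt("Please explain   Python decorators?")
b = normalize_prompt("explain Python decorators")
print(a)                                 # explain Python decorators
print(a == b)                            # True: filler word, spacing, punctuation stripped
print(hash_prompt(a) == hash_prompt(b))  # True: both map to the same cache key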