Add files via upload (initial commit)
parent 8d3d5ff628, commit b33bb415dd
24 changed files with 4840 additions and 0 deletions
54  semantic_llm_cache/__init__.py  Normal file
@@ -0,0 +1,54 @@
"""
llm-semantic-cache: Semantic caching for LLM API calls.

Cut LLM costs by up to 30% with one decorator.
"""

__version__ = "0.1.0"
__author__ = "Karthick Raja M"
__license__ = "MIT"

# Core exports
from semantic_llm_cache.config import CacheConfig
from semantic_llm_cache.core import (
    CacheContext,
    CachedLLM,
    cache,
    get_default_backend,
    set_default_backend,
)
from semantic_llm_cache.exceptions import (
    CacheBackendError,
    CacheNotFoundError,
    CacheSerializationError,
    PromptCacheError,
)
from semantic_llm_cache.stats import CacheStats, clear_cache, get_stats, invalidate
from semantic_llm_cache.storage import StorageBackend

__all__ = [
    # Version info
    "__version__",
    "__author__",
    "__license__",
    # Core API
    "cache",
    "CacheContext",
    "CachedLLM",
    "get_default_backend",
    "set_default_backend",
    # Storage
    "StorageBackend",
    # Statistics
    "CacheStats",
    "get_stats",
    "clear_cache",
    "invalidate",
    # Configuration
    "CacheConfig",
    # Exceptions
    "PromptCacheError",
    "CacheBackendError",
    "CacheSerializationError",
    "CacheNotFoundError",
]
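A minimal usage sketch of the public API exported above (assumes the package is importable; the decorated function is a hypothetical stand-in for a real LLM call, and the default exact-match threshold avoids any optional embedding dependency):

import asyncio

from semantic_llm_cache import cache, get_stats

@cache(ttl=3600)
async def ask(prompt: str) -> str:
    return f"echo: {prompt}"  # hypothetical stand-in for a real LLM call

async def main() -> None:
    await ask("hello")   # miss: executes the function and stores the result
    await ask("hello")   # exact-match hit: served from the in-memory backend
    print(get_stats())   # hits/misses/hit_rate across all namespaces

asyncio.run(main())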
21  semantic_llm_cache/backends/__init__.py  Normal file
@@ -0,0 +1,21 @@
"""Storage backends for llm-semantic-cache."""

from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.backends.memory import MemoryBackend

try:
    from semantic_llm_cache.backends.sqlite import SQLiteBackend
except ImportError:
    SQLiteBackend = None  # type: ignore

try:
    from semantic_llm_cache.backends.redis import RedisBackend
except ImportError:
    RedisBackend = None  # type: ignore

__all__ = [
    "BaseBackend",
    "MemoryBackend",
    "SQLiteBackend",
    "RedisBackend",
]
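Because the optional backends collapse to None when their extras are missing, callers can feature-detect before instantiating. A short sketch (assumes the sqlite extra may or may not be installed):

from semantic_llm_cache.backends import MemoryBackend, SQLiteBackend

# Prefer the persistent backend when available, else fall back to memory.
backend = SQLiteBackend("cache.db") if SQLiteBackend is not None else MemoryBackend()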
104  semantic_llm_cache/backends/base.py  Normal file
@@ -0,0 +1,104 @@
"""Base backend implementation with common functionality."""

import time
from typing import Any, Optional

import numpy as np

from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.storage import StorageBackend


def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float:
    """Calculate cosine similarity between two vectors.

    Args:
        a: First vector
        b: Second vector

    Returns:
        Similarity score in [-1, 1] (non-negative when both vectors are
        non-negative, as with the hash-based fallback embeddings)
    """
    a_arr = np.asarray(a, dtype=np.float32)
    b_arr = np.asarray(b, dtype=np.float32)

    dot_product = np.dot(a_arr, b_arr)
    norm_a = np.linalg.norm(a_arr)
    norm_b = np.linalg.norm(b_arr)

    if norm_a == 0 or norm_b == 0:
        return 0.0

    return float(dot_product / (norm_a * norm_b))


class BaseBackend(StorageBackend):
    """Base backend with common sync helpers; async public interface via StorageBackend."""

    def __init__(self) -> None:
        """Initialize base backend."""
        self._hits: int = 0
        self._misses: int = 0

    def _increment_hits(self) -> None:
        """Increment hit counter."""
        self._hits += 1

    def _increment_misses(self) -> None:
        """Increment miss counter."""
        self._misses += 1

    def _check_expired(self, entry: CacheEntry) -> bool:
        """Check if entry is expired.

        Args:
            entry: CacheEntry to check

        Returns:
            True if expired, False otherwise
        """
        return entry.is_expired(time.time())

    def _find_best_match(
        self,
        candidates: list[tuple[str, CacheEntry]],
        query_embedding: list[float],
        threshold: float,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        """Find best matching entry from candidates.

        Sync helper: CPU-only numpy ops, safe to call from an async context.

        Args:
            candidates: List of (key, entry) tuples
            query_embedding: Query embedding vector
            threshold: Minimum similarity threshold

        Returns:
            (key, entry, similarity) tuple if found above threshold, None otherwise
        """
        best_match: Optional[tuple[str, CacheEntry, float]] = None
        best_similarity = threshold

        for key, entry in candidates:
            if entry.embedding is None:
                continue

            similarity = cosine_similarity(query_embedding, entry.embedding)
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = (key, entry, similarity)

        return best_match

    async def get_stats(self) -> dict[str, Any]:
        """Get backend statistics.

        Returns:
            Dictionary with hits and misses
        """
        return {
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": self._hits / max(self._hits + self._misses, 1),
        }
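A quick check of the similarity helper's behavior; the values follow directly from the formula above (identical direction scores 1.0, orthogonal vectors score 0.0, and scaling a vector leaves the score unchanged):

from semantic_llm_cache.backends.base import cosine_similarity

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0: same direction
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0: orthogonal
print(cosine_similarity([1.0, 2.0], [2.0, 4.0]))  # ~1.0: scale-invariant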
179  semantic_llm_cache/backends/memory.py  Normal file
@@ -0,0 +1,179 @@
"""In-memory storage backend."""

import sys
from typing import Any, Optional

from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.exceptions import CacheBackendError


class MemoryBackend(BaseBackend):
    """In-memory cache storage with LRU eviction.

    All operations are plain in-memory dict access (no I/O), so async methods
    run directly in the event loop without thread offloading.
    """

    def __init__(self, max_size: Optional[int] = None) -> None:
        """Initialize memory backend.

        Args:
            max_size: Maximum number of entries to store (LRU eviction when reached)
        """
        super().__init__()
        self._cache: dict[str, CacheEntry] = {}
        self._access_order: dict[str, float] = {}
        self._max_size = max_size
        self._access_counter: float = 0.0

    def _evict_if_needed(self) -> None:
        """Evict the least recently used entry if at capacity."""
        if self._max_size is None or len(self._cache) < self._max_size:
            return

        if self._access_order:
            lru_key = min(self._access_order, key=lambda k: self._access_order.get(k, 0))
            del self._cache[lru_key]
            del self._access_order[lru_key]

    def _update_access_time(self, key: str) -> None:
        """Update access time for LRU tracking."""
        self._access_counter += 1
        self._access_order[key] = self._access_counter

    async def get(self, key: str) -> Optional[CacheEntry]:
        """Retrieve cache entry by key.

        Args:
            key: Cache key to retrieve

        Returns:
            CacheEntry if found and not expired, None otherwise
        """
        try:
            entry = self._cache.get(key)
            if entry is None:
                self._increment_misses()
                return None

            if self._check_expired(entry):
                await self.delete(key)
                self._increment_misses()
                return None

            self._increment_hits()
            self._update_access_time(key)
            entry.hit_count += 1
            return entry
        except Exception as e:
            raise CacheBackendError(f"Failed to get entry: {e}") from e

    async def set(self, key: str, entry: CacheEntry) -> None:
        """Store cache entry.

        Args:
            key: Cache key to store under
            entry: CacheEntry to store
        """
        try:
            self._evict_if_needed()
            self._cache[key] = entry
            self._update_access_time(key)
        except Exception as e:
            raise CacheBackendError(f"Failed to set entry: {e}") from e

    async def delete(self, key: str) -> bool:
        """Delete cache entry.

        Args:
            key: Cache key to delete

        Returns:
            True if entry was deleted, False if not found
        """
        try:
            if key in self._cache:
                del self._cache[key]
                self._access_order.pop(key, None)
                return True
            return False
        except Exception as e:
            raise CacheBackendError(f"Failed to delete entry: {e}") from e

    async def clear(self) -> None:
        """Clear all cache entries."""
        try:
            self._cache.clear()
            self._access_order.clear()
        except Exception as e:
            raise CacheBackendError(f"Failed to clear cache: {e}") from e

    async def iterate(
        self, namespace: Optional[str] = None
    ) -> list[tuple[str, CacheEntry]]:
        """Iterate over non-expired cache entries, optionally filtered by namespace.

        Args:
            namespace: Optional namespace filter

        Returns:
            List of (key, entry) tuples
        """
        try:
            # Skip expired entries in both branches so the unfiltered and
            # namespace-filtered results behave consistently.
            return [
                (k, v)
                for k, v in self._cache.items()
                if (namespace is None or v.namespace == namespace)
                and not self._check_expired(v)
            ]
        except Exception as e:
            raise CacheBackendError(f"Failed to iterate entries: {e}") from e

    async def find_similar(
        self,
        embedding: list[float],
        threshold: float,
        namespace: Optional[str] = None,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        """Find semantically similar cached entry.

        Args:
            embedding: Query embedding vector
            threshold: Minimum similarity score (0-1)
            namespace: Optional namespace filter

        Returns:
            (key, entry, similarity) tuple if found above threshold, None otherwise
        """
        try:
            candidates = [
                (k, v)
                for k, v in self._cache.items()
                if v.embedding is not None
                and not self._check_expired(v)
                and (namespace is None or v.namespace == namespace)
            ]
            return self._find_best_match(candidates, embedding, threshold)
        except Exception as e:
            raise CacheBackendError(f"Failed to find similar entry: {e}") from e

    async def get_stats(self) -> dict[str, Any]:
        """Get backend statistics.

        Returns:
            Dictionary with size, memory usage, hits, misses
        """
        base_stats = await super().get_stats()
        memory_usage = sys.getsizeof(self._cache) + sum(
            sys.getsizeof(k) + sys.getsizeof(v) for k, v in self._cache.items()
        )

        return {
            **base_stats,
            "size": len(self._cache),
            "memory_bytes": memory_usage,
            "max_size": self._max_size,
        }
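A small sketch of the backend in isolation (normally the cache decorator drives it; the entry fields follow the CacheEntry dataclass in config.py):

import asyncio
import time

from semantic_llm_cache.backends.memory import MemoryBackend
from semantic_llm_cache.config import CacheEntry

async def main() -> None:
    backend = MemoryBackend(max_size=100)
    await backend.set(
        "k1", CacheEntry(prompt="hi", response="hello", created_at=time.time(), ttl=60)
    )
    entry = await backend.get("k1")            # hit: bumps hit_count and LRU order
    print(entry.response if entry else None)   # hello
    print(await backend.get_stats())           # hits, misses, size, memory_bytes

asyncio.run(main())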
239  semantic_llm_cache/backends/redis.py  Normal file
@@ -0,0 +1,239 @@
"""Redis distributed storage backend (async via redis.asyncio)."""

import json
from typing import Any, Optional

try:
    from redis import asyncio as aioredis
except ImportError as err:
    raise ImportError(
        "Redis backend requires 'redis' package. "
        "Install with: pip install semantic-llm-cache[redis]"
    ) from err

from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.exceptions import CacheBackendError


class RedisBackend(BaseBackend):
    """Redis-based distributed cache storage (async).

    Uses redis.asyncio (bundled with redis>=4.2) for non-blocking I/O.
    The connection is created in __init__; no explicit connect() call is
    needed because redis.asyncio uses a connection pool that connects lazily.
    """

    DEFAULT_PREFIX = "semantic_llm_cache:"

    def __init__(
        self,
        url: str = "redis://localhost:6379/0",
        prefix: str = DEFAULT_PREFIX,
        **kwargs: Any,
    ) -> None:
        """Initialize Redis backend.

        Args:
            url: Redis connection URL
            prefix: Key prefix for cache entries
            **kwargs: Additional arguments passed to redis.asyncio.from_url
        """
        super().__init__()
        self._prefix = prefix.rstrip(":") + ":"
        self._redis = aioredis.from_url(url, **kwargs)

    async def ping(self) -> None:
        """Test Redis connection. Call this after construction to verify connectivity.

        Raises:
            CacheBackendError: If Redis is not reachable
        """
        try:
            await self._redis.ping()
        except Exception as e:
            raise CacheBackendError(f"Failed to connect to Redis: {e}") from e

    def _make_key(self, key: str) -> str:
        """Create full Redis key with prefix."""
        return f"{self._prefix}{key}"

    def _entry_to_dict(self, entry: CacheEntry) -> dict[str, Any]:
        """Convert CacheEntry to dictionary for storage."""
        return {
            "prompt": entry.prompt,
            "response": entry.response,
            "embedding": entry.embedding,
            "created_at": entry.created_at,
            "ttl": entry.ttl,
            "namespace": entry.namespace,
            "hit_count": entry.hit_count,
            "input_tokens": entry.input_tokens,
            "output_tokens": entry.output_tokens,
        }

    def _dict_to_entry(self, data: dict[str, Any]) -> CacheEntry:
        """Convert dictionary from storage to CacheEntry."""
        return CacheEntry(
            prompt=data["prompt"],
            response=data["response"],
            embedding=data.get("embedding"),
            created_at=data["created_at"],
            ttl=data.get("ttl"),
            namespace=data.get("namespace", "default"),
            hit_count=data.get("hit_count", 0),
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
        )

    async def get(self, key: str) -> Optional[CacheEntry]:
        """Retrieve cache entry by key.

        Args:
            key: Cache key to retrieve

        Returns:
            CacheEntry if found and not expired, None otherwise
        """
        try:
            redis_key = self._make_key(key)
            data = await self._redis.get(redis_key)

            if data is None:
                self._increment_misses()
                return None

            entry_dict = json.loads(data)
            entry = self._dict_to_entry(entry_dict)

            if self._check_expired(entry):
                await self.delete(key)
                self._increment_misses()
                return None

            self._increment_hits()
            entry.hit_count += 1

            # Write back the updated hit count; keepttl (Redis 6.0+) preserves
            # the key's remaining TTL, which a plain SET would discard.
            entry_dict["hit_count"] = entry.hit_count
            await self._redis.set(redis_key, json.dumps(entry_dict), keepttl=True)

            return entry
        except Exception as e:
            raise CacheBackendError(f"Failed to get entry: {e}") from e

    async def set(self, key: str, entry: CacheEntry) -> None:
        """Store cache entry.

        Args:
            key: Cache key to store under
            entry: CacheEntry to store
        """
        try:
            redis_key = self._make_key(key)
            data = json.dumps(self._entry_to_dict(entry))
            redis_ttl = entry.ttl if entry.ttl is not None else 0
            await self._redis.set(redis_key, data, ex=redis_ttl if redis_ttl > 0 else None)
        except Exception as e:
            raise CacheBackendError(f"Failed to set entry: {e}") from e

    async def delete(self, key: str) -> bool:
        """Delete cache entry.

        Args:
            key: Cache key to delete

        Returns:
            True if entry was deleted, False if not found
        """
        try:
            result = await self._redis.delete(self._make_key(key))
            return result > 0
        except Exception as e:
            raise CacheBackendError(f"Failed to delete entry: {e}") from e

    async def clear(self) -> None:
        """Clear all cache entries with this prefix."""
        try:
            keys = await self._redis.keys(f"{self._prefix}*")
            if keys:
                await self._redis.delete(*keys)
        except Exception as e:
            raise CacheBackendError(f"Failed to clear cache: {e}") from e

    async def iterate(
        self, namespace: Optional[str] = None
    ) -> list[tuple[str, CacheEntry]]:
        """Iterate over cache entries, optionally filtered by namespace.

        Note: KEYS is O(N) over the whole keyspace; fine for moderate cache
        sizes, but consider SCAN for very large deployments.

        Args:
            namespace: Optional namespace filter

        Returns:
            List of (key, entry) tuples
        """
        try:
            keys = await self._redis.keys(f"{self._prefix}*")
            results = []

            for full_key in keys:
                # Keys are bytes unless the client was created with
                # decode_responses=True.
                raw_key = full_key.decode() if isinstance(full_key, bytes) else full_key
                short_key = raw_key.replace(self._prefix, "", 1)
                data = await self._redis.get(full_key)

                if data:
                    entry_dict = json.loads(data)
                    entry = self._dict_to_entry(entry_dict)

                    if namespace is None or entry.namespace == namespace:
                        if not self._check_expired(entry):
                            results.append((short_key, entry))

            return results
        except Exception as e:
            raise CacheBackendError(f"Failed to iterate entries: {e}") from e

    async def find_similar(
        self,
        embedding: list[float],
        threshold: float,
        namespace: Optional[str] = None,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        """Find semantically similar cached entry.

        Note: Loads all entries for a full cosine scan. For large datasets
        consider Redis Stack with vector search (RediSearch).

        Args:
            embedding: Query embedding vector
            threshold: Minimum similarity score (0-1)
            namespace: Optional namespace filter

        Returns:
            (key, entry, similarity) tuple if found above threshold, None otherwise
        """
        try:
            entries = await self.iterate(namespace)
            candidates = [(k, v) for k, v in entries if v.embedding is not None]
            return self._find_best_match(candidates, embedding, threshold)
        except Exception as e:
            raise CacheBackendError(f"Failed to find similar entry: {e}") from e

    async def get_stats(self) -> dict[str, Any]:
        """Get backend statistics."""
        base_stats = await super().get_stats()

        try:
            keys = await self._redis.keys(f"{self._prefix}*")
            return {
                **base_stats,
                "size": len(keys) if keys else 0,
                "prefix": self._prefix,
            }
        except Exception as e:
            return {**base_stats, "size": 0, "prefix": self._prefix, "error": str(e)}

    async def close(self) -> None:
        """Close Redis connection."""
        try:
            await self._redis.aclose()
        except Exception:
            pass
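A connection sketch, assuming a Redis server reachable on localhost (the URL and db index are placeholders; ping() surfaces connectivity errors early because the pool otherwise connects lazily):

import asyncio

from semantic_llm_cache.backends.redis import RedisBackend

async def main() -> None:
    backend = RedisBackend(url="redis://localhost:6379/0")
    await backend.ping()              # raises CacheBackendError if unreachable
    print(await backend.get_stats())  # hits, misses, size, prefix
    await backend.close()

asyncio.run(main())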
279  semantic_llm_cache/backends/sqlite.py  Normal file
@@ -0,0 +1,279 @@
"""SQLite persistent storage backend (async via aiosqlite)."""

import json
from pathlib import Path
from typing import Any, Optional

try:
    import aiosqlite
except ImportError as err:
    raise ImportError(
        "SQLite backend requires 'aiosqlite' package. "
        "Install with: pip install semantic-llm-cache[sqlite]"
    ) from err

from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.config import CacheEntry
from semantic_llm_cache.exceptions import CacheBackendError


class SQLiteBackend(BaseBackend):
    """SQLite-based persistent cache storage (async).

    Uses aiosqlite for non-blocking I/O. A single persistent connection
    is opened lazily on first use and reused for all subsequent operations.
    """

    def __init__(self, db_path: str | Path = "semantic_cache.db") -> None:
        """Initialize SQLite backend.

        Args:
            db_path: Path to SQLite database file, or ":memory:" for in-memory DB
        """
        super().__init__()
        self._db_path = str(db_path) if isinstance(db_path, Path) else db_path
        self._conn: Optional[aiosqlite.Connection] = None

    async def _get_conn(self) -> aiosqlite.Connection:
        """Get or create the persistent async connection."""
        if self._conn is None:
            self._conn = await aiosqlite.connect(self._db_path)
            self._conn.row_factory = aiosqlite.Row
            await self._initialize_schema()
        return self._conn

    async def _initialize_schema(self) -> None:
        """Initialize database schema."""
        conn = await self._get_conn()
        await conn.execute(
            """
            CREATE TABLE IF NOT EXISTS cache_entries (
                key TEXT PRIMARY KEY,
                prompt TEXT NOT NULL,
                response TEXT NOT NULL,
                embedding TEXT,
                created_at REAL NOT NULL,
                ttl INTEGER,
                namespace TEXT NOT NULL DEFAULT 'default',
                hit_count INTEGER DEFAULT 0,
                input_tokens INTEGER DEFAULT 0,
                output_tokens INTEGER DEFAULT 0
            )
            """
        )
        await conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_namespace
            ON cache_entries(namespace)
            """
        )
        await conn.commit()

    def _row_to_entry(self, row: aiosqlite.Row) -> CacheEntry:
        """Convert database row to CacheEntry."""
        embedding = None
        if row["embedding"]:
            embedding = json.loads(row["embedding"])

        return CacheEntry(
            prompt=row["prompt"],
            response=json.loads(row["response"]),
            embedding=embedding,
            created_at=row["created_at"],
            ttl=row["ttl"],
            namespace=row["namespace"],
            hit_count=row["hit_count"],
            input_tokens=row["input_tokens"],
            output_tokens=row["output_tokens"],
        )

    async def get(self, key: str) -> Optional[CacheEntry]:
        """Retrieve cache entry by key.

        Args:
            key: Cache key to retrieve

        Returns:
            CacheEntry if found and not expired, None otherwise
        """
        try:
            conn = await self._get_conn()
            async with conn.execute(
                "SELECT * FROM cache_entries WHERE key = ?", (key,)
            ) as cursor:
                row = await cursor.fetchone()

            if row is None:
                self._increment_misses()
                return None

            entry = self._row_to_entry(row)

            if self._check_expired(entry):
                await self.delete(key)
                self._increment_misses()
                return None

            self._increment_hits()
            entry.hit_count += 1

            await conn.execute(
                "UPDATE cache_entries SET hit_count = hit_count + 1 WHERE key = ?",
                (key,),
            )
            await conn.commit()

            return entry
        except Exception as e:
            raise CacheBackendError(f"Failed to get entry: {e}") from e

    async def set(self, key: str, entry: CacheEntry) -> None:
        """Store cache entry.

        Args:
            key: Cache key to store under
            entry: CacheEntry to store
        """
        try:
            conn = await self._get_conn()
            embedding_json = json.dumps(entry.embedding) if entry.embedding else None
            response_json = json.dumps(entry.response)

            await conn.execute(
                """
                INSERT OR REPLACE INTO cache_entries
                (key, prompt, response, embedding, created_at, ttl, namespace,
                 hit_count, input_tokens, output_tokens)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    key,
                    entry.prompt,
                    response_json,
                    embedding_json,
                    entry.created_at,
                    entry.ttl,
                    entry.namespace,
                    entry.hit_count,
                    entry.input_tokens,
                    entry.output_tokens,
                ),
            )
            await conn.commit()
        except Exception as e:
            raise CacheBackendError(f"Failed to set entry: {e}") from e

    async def delete(self, key: str) -> bool:
        """Delete cache entry.

        Args:
            key: Cache key to delete

        Returns:
            True if entry was deleted, False if not found
        """
        try:
            conn = await self._get_conn()
            async with conn.execute(
                "DELETE FROM cache_entries WHERE key = ?", (key,)
            ) as cursor:
                rowcount = cursor.rowcount
            await conn.commit()
            return rowcount > 0
        except Exception as e:
            raise CacheBackendError(f"Failed to delete entry: {e}") from e

    async def clear(self) -> None:
        """Clear all cache entries."""
        try:
            conn = await self._get_conn()
            await conn.execute("DELETE FROM cache_entries")
            await conn.commit()
        except Exception as e:
            raise CacheBackendError(f"Failed to clear cache: {e}") from e

    async def iterate(
        self, namespace: Optional[str] = None
    ) -> list[tuple[str, CacheEntry]]:
        """Iterate over non-expired cache entries, optionally filtered by namespace.

        Args:
            namespace: Optional namespace filter

        Returns:
            List of (key, entry) tuples
        """
        try:
            conn = await self._get_conn()

            # key is an ordinary column, so SELECT * already includes it.
            if namespace is None:
                query = "SELECT * FROM cache_entries"
                params: tuple[Any, ...] = ()
            else:
                query = "SELECT * FROM cache_entries WHERE namespace = ?"
                params = (namespace,)

            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()

            results = []
            for row in rows:
                key = row["key"]
                entry = self._row_to_entry(row)
                if not self._check_expired(entry):
                    results.append((key, entry))

            return results
        except Exception as e:
            raise CacheBackendError(f"Failed to iterate entries: {e}") from e

    async def find_similar(
        self,
        embedding: list[float],
        threshold: float,
        namespace: Optional[str] = None,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        """Find semantically similar cached entry.

        Args:
            embedding: Query embedding vector
            threshold: Minimum similarity score (0-1)
            namespace: Optional namespace filter

        Returns:
            (key, entry, similarity) tuple if found above threshold, None otherwise
        """
        try:
            entries = await self.iterate(namespace)
            candidates = [(k, v) for k, v in entries if v.embedding is not None]
            return self._find_best_match(candidates, embedding, threshold)
        except Exception as e:
            raise CacheBackendError(f"Failed to find similar entry: {e}") from e

    async def get_stats(self) -> dict[str, Any]:
        """Get backend statistics.

        Returns:
            Dictionary with size, database path, hits, misses
        """
        base_stats = await super().get_stats()

        try:
            conn = await self._get_conn()
            async with conn.execute("SELECT COUNT(*) FROM cache_entries") as cursor:
                row = await cursor.fetchone()
                size = row[0] if row else 0

            return {
                **base_stats,
                "size": size,
                "db_path": self._db_path,
            }
        except Exception as e:
            return {**base_stats, "size": 0, "db_path": self._db_path, "error": str(e)}

    async def close(self) -> None:
        """Close database connection."""
        if self._conn is not None:
            await self._conn.close()
            self._conn = None
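A short sketch of the SQLite backend; ":memory:" keeps the example self-contained, and because responses round-trip through JSON, dicts and lists are preserved:

import asyncio
import time

from semantic_llm_cache.backends.sqlite import SQLiteBackend
from semantic_llm_cache.config import CacheEntry

async def main() -> None:
    backend = SQLiteBackend(":memory:")  # pass a file path to persist across runs
    await backend.set(
        "k1",
        CacheEntry(prompt="hi", response={"text": "hello"}, created_at=time.time()),
    )
    entry = await backend.get("k1")
    print(entry.response if entry else None)  # {'text': 'hello'}
    await backend.close()

asyncio.run(main())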
61  semantic_llm_cache/config.py  Normal file
@@ -0,0 +1,61 @@
"""Configuration management for llm-semantic-cache."""

from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CacheConfig:
    """Configuration for cache behavior."""

    similarity_threshold: float = 1.0  # 1.0 = exact match, lower = semantic
    ttl: Optional[int] = 3600  # Time to live in seconds, None = forever
    namespace: str = "default"  # Isolate different use cases
    enabled: bool = True  # Enable/disable caching
    key_func: Optional[Callable[[Any], str]] = None  # Custom cache key function

    # Cost estimation for statistics (USD per 1K tokens)
    input_cost_per_1k: float = 0.001  # Default ~$1/1M tokens for cheaper models
    output_cost_per_1k: float = 0.002  # Default ~$2/1M tokens for cheaper models

    # Performance settings
    max_cache_size: Optional[int] = None  # LRU eviction when set
    embedding_model: str = "all-MiniLM-L6-v2"  # Default sentence-transformer model

    def __post_init__(self) -> None:
        """Validate configuration."""
        if not 0.0 <= self.similarity_threshold <= 1.0:
            raise ValueError("similarity_threshold must be between 0.0 and 1.0")
        if self.ttl is not None and self.ttl <= 0:
            raise ValueError("ttl must be positive or None")
        if self.max_cache_size is not None and self.max_cache_size <= 0:
            raise ValueError("max_cache_size must be positive or None")


@dataclass
class CacheEntry:
    """A cached response with metadata."""

    prompt: str
    response: Any
    embedding: Optional[list[float]] = None  # Normalized embedding vector
    created_at: float = 0.0  # Unix timestamp
    ttl: Optional[int] = None  # Time to live in seconds
    namespace: str = "default"
    hit_count: int = 0

    # Approximate token counts for cost estimation
    input_tokens: int = 0
    output_tokens: int = 0

    def is_expired(self, current_time: float) -> bool:
        """Check if entry has expired based on TTL."""
        if self.ttl is None:
            return False
        return (current_time - self.created_at) > self.ttl

    def estimate_cost(self, input_cost: float, output_cost: float) -> float:
        """Estimate cost savings in USD, given per-1K-token rates."""
        input_savings = (self.input_tokens / 1000) * input_cost
        output_savings = (self.output_tokens / 1000) * output_cost
        return input_savings + output_savings
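A worked example of the cost arithmetic: at the default rates ($0.001 per 1K input tokens, $0.002 per 1K output tokens), an entry with 500 input and 1000 output tokens saves 0.5 x 0.001 + 1.0 x 0.002 = $0.0025 per hit:

from semantic_llm_cache.config import CacheConfig, CacheEntry

config = CacheConfig()  # similarity_threshold=1.0, ttl=3600 by default
entry = CacheEntry(
    prompt="q",
    response="a",
    created_at=0.0,
    input_tokens=500,
    output_tokens=1000,
)
saved = entry.estimate_cost(config.input_cost_per_1k, config.output_cost_per_1k)
print(saved)  # 0.0025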
369  semantic_llm_cache/core.py  Normal file
@@ -0,0 +1,369 @@
"""Core cache decorator and API for llm-semantic-cache."""

import functools
import inspect
import time
from typing import Any, Callable, Optional, ParamSpec, TypeVar

from semantic_llm_cache.backends import MemoryBackend
from semantic_llm_cache.backends.base import BaseBackend
from semantic_llm_cache.config import CacheConfig, CacheEntry
from semantic_llm_cache.exceptions import PromptCacheError
from semantic_llm_cache.similarity import EmbeddingCache
from semantic_llm_cache.stats import _stats_manager
from semantic_llm_cache.utils import hash_prompt, normalize_prompt

P = ParamSpec("P")
R = TypeVar("R")


def _extract_prompt(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
    """Extract prompt string from function arguments."""
    if args and isinstance(args[0], str):
        return args[0]
    if "prompt" in kwargs:
        return str(kwargs["prompt"])
    return str(args) + str(sorted(kwargs.items()))


class CacheContext:
    """Context manager for cache configuration.

    Supports both sync (with) and async (async with) usage.

    Examples:
        >>> async with CacheContext(similarity=0.9) as ctx:
        ...     result = await llm_call("prompt")
        ...     print(ctx.stats)
    """

    def __init__(
        self,
        similarity: Optional[float] = None,
        ttl: Optional[int] = None,
        namespace: Optional[str] = None,
        enabled: Optional[bool] = None,
    ) -> None:
        self._config = CacheConfig(
            similarity_threshold=similarity if similarity is not None else 1.0,
            ttl=ttl,
            namespace=namespace if namespace is not None else "default",
            enabled=enabled if enabled is not None else True,
        )
        self._stats: dict[str, Any] = {"hits": 0, "misses": 0}

    def __enter__(self) -> "CacheContext":
        return self

    def __exit__(self, *args: Any) -> None:
        pass

    async def __aenter__(self) -> "CacheContext":
        return self

    async def __aexit__(self, *args: Any) -> None:
        pass

    @property
    def stats(self) -> dict[str, Any]:
        return self._stats.copy()

    @property
    def config(self) -> CacheConfig:
        return self._config


class CachedLLM:
    """Wrapper class for LLM calls with automatic caching.

    Examples:
        >>> llm = CachedLLM(similarity=0.9)
        >>> response = await llm.achat("What is Python?", llm_func=my_async_llm)
    """

    def __init__(
        self,
        provider: str = "openai",
        model: str = "gpt-4",
        similarity: float = 1.0,
        ttl: Optional[int] = 3600,
        backend: Optional[BaseBackend] = None,
        namespace: str = "default",
        enabled: bool = True,
    ) -> None:
        self._provider = provider
        self._model = model
        self._backend = backend or MemoryBackend()
        self._embedding_cache = EmbeddingCache()
        self._config = CacheConfig(
            similarity_threshold=similarity,
            ttl=ttl,
            namespace=namespace,
            enabled=enabled,
        )

    async def achat(
        self,
        prompt: str,
        llm_func: Optional[Callable[[str], Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Get response with caching (async).

        Args:
            prompt: Input prompt
            llm_func: Async or sync LLM function to call on cache miss
            **kwargs: Additional arguments for llm_func

        Returns:
            LLM response (cached or fresh)
        """
        if llm_func is None:
            raise ValueError("llm_func is required for CachedLLM.achat()")

        # Note: the decorator is rebuilt on each call, but cached state lives
        # in self._backend, which is shared across calls.
        @cache(
            similarity=self._config.similarity_threshold,
            ttl=self._config.ttl,
            backend=self._backend,
            namespace=self._config.namespace,
            enabled=self._config.enabled,
        )
        async def _cached_call(p: str) -> Any:
            result = llm_func(p, **kwargs)
            if inspect.isawaitable(result):
                return await result
            return result

        return await _cached_call(prompt)


def cache(
    similarity: float = 1.0,
    ttl: Optional[int] = 3600,
    backend: Optional[BaseBackend] = None,
    namespace: str = "default",
    enabled: bool = True,
    key_func: Optional[Callable[..., str]] = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator for caching LLM function responses.

    Auto-detects whether the decorated function is async or sync and returns
    the appropriate wrapper. Both variants share identical cache logic.

    Async functions get a true async wrapper (awaits all backend calls).
    Sync functions get a sync wrapper that drives the async backends via a
    temporary event loop. The sync wrapper cannot run inside an already
    running loop; prefer decorating async functions when integrating with
    async frameworks like FastAPI.

    Args:
        similarity: Cosine similarity threshold (1.0=exact, 0.9=semantic)
        ttl: Time-to-live in seconds (None=forever)
        backend: Async storage backend (None=in-memory)
        namespace: Cache namespace for isolation
        enabled: Whether caching is enabled
        key_func: Custom cache key function

    Returns:
        Decorated function with caching

    Examples:
        >>> @cache(similarity=0.9, ttl=3600)
        ... async def ask_llm(prompt: str) -> str:
        ...     return await call_ollama(prompt)

        >>> @cache()
        ... def ask_llm_sync(prompt: str) -> str:
        ...     return call_ollama_sync(prompt)
    """
    _backend = backend or MemoryBackend()
    embedding_cache = EmbeddingCache()

    def decorator(func: Callable[P, R]) -> Callable[P, R]:

        if inspect.iscoroutinefunction(func):
            # ── Async wrapper ────────────────────────────────────────────────
            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                if not enabled:
                    return await func(*args, **kwargs)

                start_time = time.time()
                prompt = _extract_prompt(args, kwargs)  # type: ignore[arg-type]
                normalized = normalize_prompt(prompt)
                cache_key = (
                    key_func(*args, **kwargs)  # type: ignore[arg-type]
                    if key_func
                    else hash_prompt(normalized, namespace)
                )

                # 1. Exact match
                entry = await _backend.get(cache_key)
                if entry is not None:
                    latency_ms = (time.time() - start_time) * 1000
                    # Cost estimate uses the CacheConfig default rates
                    # (USD per 1K tokens).
                    _stats_manager.record_hit(
                        namespace,
                        latency_saved_ms=latency_ms,
                        saved_cost=entry.estimate_cost(0.001, 0.002),
                    )
                    return entry.response  # type: ignore[return-value]

                # 2. Semantic match
                if similarity < 1.0:
                    query_embedding = await embedding_cache.aencode(normalized)
                    result = await _backend.find_similar(
                        query_embedding, threshold=similarity, namespace=namespace
                    )
                    if result is not None:
                        _, matched_entry, _ = result
                        latency_ms = (time.time() - start_time) * 1000
                        _stats_manager.record_hit(
                            namespace,
                            latency_saved_ms=latency_ms,
                            saved_cost=matched_entry.estimate_cost(0.001, 0.002),
                        )
                        return matched_entry.response  # type: ignore[return-value]

                # 3. Cache miss: call through
                _stats_manager.record_miss(namespace)

                try:
                    response = await func(*args, **kwargs)
                except Exception as e:
                    raise PromptCacheError(f"LLM function call failed: {e}") from e

                embedding = None
                if similarity < 1.0:
                    embedding = await embedding_cache.aencode(normalized)

                await _backend.set(
                    cache_key,
                    CacheEntry(
                        prompt=normalized,
                        response=response,
                        embedding=embedding,
                        created_at=time.time(),
                        ttl=ttl,
                        namespace=namespace,
                        hit_count=0,
                        # Rough token estimate: ~4 characters per token.
                        input_tokens=len(normalized) // 4,
                        output_tokens=len(str(response)) // 4,
                    ),
                )
                return response  # type: ignore[return-value]

            return async_wrapper  # type: ignore[return-value]

        else:
            # ── Sync wrapper (backwards compatibility) ───────────────────────
            # Drives async backends via a dedicated event loop per call.
            # Do NOT use inside a running event loop (e.g. FastAPI handlers).
            import asyncio

            @functools.wraps(func)
            def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                if not enabled:
                    return func(*args, **kwargs)

                start_time = time.time()
                prompt = _extract_prompt(args, kwargs)  # type: ignore[arg-type]
                normalized = normalize_prompt(prompt)
                cache_key = (
                    key_func(*args, **kwargs)  # type: ignore[arg-type]
                    if key_func
                    else hash_prompt(normalized, namespace)
                )

                loop = asyncio.new_event_loop()
                try:
                    # 1. Exact match
                    entry = loop.run_until_complete(_backend.get(cache_key))
                    if entry is not None:
                        latency_ms = (time.time() - start_time) * 1000
                        _stats_manager.record_hit(
                            namespace,
                            latency_saved_ms=latency_ms,
                            saved_cost=entry.estimate_cost(0.001, 0.002),
                        )
                        return entry.response  # type: ignore[return-value]

                    # 2. Semantic match
                    if similarity < 1.0:
                        query_embedding = embedding_cache.encode(normalized)
                        result = loop.run_until_complete(
                            _backend.find_similar(
                                query_embedding, threshold=similarity, namespace=namespace
                            )
                        )
                        if result is not None:
                            _, matched_entry, _ = result
                            latency_ms = (time.time() - start_time) * 1000
                            _stats_manager.record_hit(
                                namespace,
                                latency_saved_ms=latency_ms,
                                saved_cost=matched_entry.estimate_cost(0.001, 0.002),
                            )
                            return matched_entry.response  # type: ignore[return-value]

                    # 3. Cache miss
                    _stats_manager.record_miss(namespace)

                    try:
                        response = func(*args, **kwargs)
                    except Exception as e:
                        raise PromptCacheError(f"LLM function call failed: {e}") from e

                    embedding = None
                    if similarity < 1.0:
                        embedding = embedding_cache.encode(normalized)

                    loop.run_until_complete(
                        _backend.set(
                            cache_key,
                            CacheEntry(
                                prompt=normalized,
                                response=response,
                                embedding=embedding,
                                created_at=time.time(),
                                ttl=ttl,
                                namespace=namespace,
                                hit_count=0,
                                input_tokens=len(normalized) // 4,
                                output_tokens=len(str(response)) // 4,
                            ),
                        )
                    )
                    return response  # type: ignore[return-value]
                finally:
                    loop.close()

            return sync_wrapper  # type: ignore[return-value]

    return decorator


# Global default backend for utility functions
_default_backend: Optional[BaseBackend] = None


def get_default_backend() -> BaseBackend:
    """Get default storage backend."""
    global _default_backend
    if _default_backend is None:
        _default_backend = MemoryBackend()
    return _default_backend


def set_default_backend(backend: BaseBackend) -> None:
    """Set default storage backend."""
    global _default_backend
    _default_backend = backend
    _stats_manager.set_backend(backend)


__all__ = [
    "cache",
    "CacheContext",
    "CachedLLM",
    "get_default_backend",
    "set_default_backend",
]
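A usage sketch for CachedLLM with a hypothetical async stand-in for a provider call (exact matching, so no embedding extra is required; the second call is served from the instance's shared backend):

import asyncio

from semantic_llm_cache.core import CachedLLM

async def fake_llm(prompt: str) -> str:
    await asyncio.sleep(0.1)  # hypothetical stand-in for a real API call
    return f"answer to: {prompt}"

async def main() -> None:
    llm = CachedLLM(similarity=1.0, ttl=3600)
    first = await llm.achat("What is Python?", llm_func=fake_llm)   # miss
    second = await llm.achat("What is Python?", llm_func=fake_llm)  # hit
    print(first == second)  # True

asyncio.run(main())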
25  semantic_llm_cache/exceptions.py  Normal file
@@ -0,0 +1,25 @@
"""Custom exceptions for llm-semantic-cache."""


class PromptCacheError(Exception):
    """Base exception for all llm-semantic-cache errors."""

    pass


class CacheBackendError(PromptCacheError):
    """Exception raised when backend operations fail."""

    pass


class CacheSerializationError(PromptCacheError):
    """Exception raised when serialization/deserialization fails."""

    pass


class CacheNotFoundError(PromptCacheError):
    """Exception raised when cache entry is not found."""

    pass
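Since every library error derives from PromptCacheError, callers that do not care which layer failed can catch the base class; a brief sketch:

from semantic_llm_cache.exceptions import CacheBackendError, PromptCacheError

try:
    raise CacheBackendError("backend unreachable")
except PromptCacheError as e:  # catches any semantic_llm_cache error
    print(type(e).__name__, e)  # CacheBackendError backend unreachable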
1  semantic_llm_cache/py.typed  Normal file
@@ -0,0 +1 @@
# PEP 561 marker file for type hints
283  semantic_llm_cache/similarity.py  Normal file
@@ -0,0 +1,283 @@
"""Embedding generation and similarity matching for llm-semantic-cache."""

import asyncio
import hashlib
from functools import lru_cache
from typing import Optional

import numpy as np

from semantic_llm_cache.exceptions import PromptCacheError


class EmbeddingProvider:
    """Base class for embedding providers."""

    def encode(self, text: str) -> list[float]:
        """Generate embedding for text.

        Args:
            text: Input text to encode

        Returns:
            Embedding vector as list of floats
        """
        raise NotImplementedError


class DummyEmbeddingProvider(EmbeddingProvider):
    """Fallback embedding provider using hash-based vectors.

    Provides consistent embeddings without external dependencies.
    Not semantically meaningful, but yields deterministic cache keys.
    """

    def __init__(self, dim: int = 384) -> None:
        """Initialize dummy provider.

        Args:
            dim: Embedding dimension (matches the MiniLM default)
        """
        self._dim = dim

    def encode(self, text: str) -> list[float]:
        """Generate hash-based embedding for text.

        Args:
            text: Input text to encode

        Returns:
            Deterministic embedding vector based on text hash
        """
        hash_obj = hashlib.sha256(text.encode())
        hash_bytes = hash_obj.digest()

        values = np.frombuffer(hash_bytes, dtype=np.uint8)[: self._dim].astype(
            np.float32
        )

        if len(values) < self._dim:
            values = np.pad(values, (0, self._dim - len(values)))

        norm = np.linalg.norm(values)
        if norm > 0:
            values = values / norm

        return values.tolist()


class SentenceTransformerProvider(EmbeddingProvider):
    """Sentence-transformers based embedding provider.

    Uses local models like MiniLM for semantic embeddings.
    Inference is CPU/GPU-bound; use aencode() from async contexts.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
        """Initialize sentence-transformer provider.

        Args:
            model_name: Name of sentence-transformer model
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as e:
            raise PromptCacheError(
                "sentence-transformers package required for semantic matching. "
                "Install with: pip install semantic-llm-cache[semantic]"
            ) from e

        self._model = SentenceTransformer(model_name)
        self._dim = self._model.get_sentence_embedding_dimension()

    def encode(self, text: str) -> list[float]:
        """Generate embedding for text (blocking; use aencode from async code).

        Args:
            text: Input text to encode

        Returns:
            Normalized embedding vector
        """
        embedding = self._model.encode(text, convert_to_numpy=True)
        embedding = np.asarray(embedding, dtype=np.float32)

        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm

        return embedding.tolist()


class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI API-based embedding provider.

    Uses OpenAI's embedding API for high-quality semantic embeddings.
    Network I/O; always use aencode() from async contexts.
    """

    def __init__(
        self, api_key: Optional[str] = None, model: str = "text-embedding-3-small"
    ) -> None:
        """Initialize OpenAI embedding provider.

        Args:
            api_key: OpenAI API key (uses OPENAI_API_KEY env var if None)
            model: OpenAI embedding model to use
        """
        try:
            import openai
        except ImportError as e:
            raise PromptCacheError(
                "openai package required for OpenAI embeddings. "
                "Install with: pip install semantic-llm-cache[openai]"
            ) from e

        self._client = openai.OpenAI(api_key=api_key)
        self._model = model

    def encode(self, text: str) -> list[float]:
        """Generate embedding for text (blocking; use aencode from async code).

        Args:
            text: Input text to encode

        Returns:
            OpenAI embedding vector (already normalized)
        """
        response = self._client.embeddings.create(input=text, model=self._model)
        embedding = response.data[0].embedding

        embedding_arr = np.asarray(embedding, dtype=np.float32)
        norm = np.linalg.norm(embedding_arr)
        if norm > 0:
            embedding_arr = embedding_arr / norm

        return embedding_arr.tolist()


def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float:
    """Calculate cosine similarity between two vectors.

    Args:
        a: First vector
        b: Second vector

    Returns:
        Similarity score in [-1, 1] (non-negative for non-negative vectors)

    Raises:
        ValueError: If vectors have different dimensions
    """
    a_arr = np.asarray(a, dtype=np.float32)
    b_arr = np.asarray(b, dtype=np.float32)

    if a_arr.shape != b_arr.shape:
        raise ValueError(
            f"Vector dimension mismatch: {a_arr.shape} != {b_arr.shape}"
        )

    dot_product = np.dot(a_arr, b_arr)
    norm_a = np.linalg.norm(a_arr)
    norm_b = np.linalg.norm(b_arr)

    if norm_a == 0 or norm_b == 0:
        return 0.0

    return float(dot_product / (norm_a * norm_b))


def _encode_with_provider(text: str, provider: EmbeddingProvider) -> tuple[float, ...]:
    """Helper function for LRU cache encoding.

    Args:
        text: Input text
        provider: Embedding provider

    Returns:
        Embedding as tuple for hashability
    """
    return tuple(provider.encode(text))


class EmbeddingCache:
    """Cache for embedding generation with LRU eviction.

    Use encode() from sync contexts, aencode() from async contexts.
    aencode() offloads blocking inference to a thread pool via asyncio.to_thread.
    """

    def __init__(
        self,
        provider: Optional[EmbeddingProvider] = None,
        cache_size: int = 1024,
    ) -> None:
        """Initialize embedding cache.

        Args:
            provider: Embedding provider (uses DummyEmbeddingProvider if None)
            cache_size: Maximum number of embeddings to cache
        """
        self._provider = provider or DummyEmbeddingProvider()
        self._cache_size = cache_size
        self._get_cached = lru_cache(maxsize=cache_size)(_encode_with_provider)

    def encode(self, text: str) -> list[float]:
        """Generate embedding with LRU caching (sync, blocking).

        Args:
            text: Input text to encode

        Returns:
            Embedding vector
        """
        return list(self._get_cached(text, self._provider))

    async def aencode(self, text: str) -> list[float]:
        """Generate embedding with LRU caching (async, non-blocking).

        CPU/network-bound work is offloaded to the default thread pool via
        asyncio.to_thread, keeping the event loop free.

        Args:
            text: Input text to encode

        Returns:
            Embedding vector
        """
        return await asyncio.to_thread(self.encode, text)

    def clear_cache(self) -> None:
        """Clear the embedding LRU cache."""
        self._get_cached.cache_clear()


def create_embedding_provider(
    provider_type: str = "auto",
    model_name: Optional[str] = None,
) -> EmbeddingProvider:
    """Create embedding provider based on type.

    Args:
        provider_type: Type of provider ("auto", "sentence-transformer", "openai", "dummy")
        model_name: Optional model name to use

    Returns:
        EmbeddingProvider instance
    """
    if provider_type == "auto":
        try:
            return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")
        except PromptCacheError:
            return DummyEmbeddingProvider()

    if provider_type == "sentence-transformer":
        return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")

    if provider_type == "openai":
        # Fall back to the provider's default model when none is given;
        # passing None through would override the keyword default.
        return OpenAIEmbeddingProvider(model=model_name or "text-embedding-3-small")

    if provider_type == "dummy":
        return DummyEmbeddingProvider()

    raise ValueError(f"Unknown provider type: {provider_type}")
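A sketch of the embedding cache with the dependency-free fallback provider; the hash-based vectors are deterministic but not semantic, so identical strings score 1.0 while similar wording does not:

from semantic_llm_cache.similarity import (
    DummyEmbeddingProvider,
    EmbeddingCache,
    cosine_similarity,
)

emb_cache = EmbeddingCache(provider=DummyEmbeddingProvider(), cache_size=128)
v1 = emb_cache.encode("hello world")  # computed once
v2 = emb_cache.encode("hello world")  # served from the LRU cache
print(cosine_similarity(v1, v2))      # 1.0: identical deterministic vectors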
255
semantic_llm_cache/stats.py
Normal file
255
semantic_llm_cache/stats.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
"""Statistics and analytics for llm-semantic-cache."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from semantic_llm_cache.backends import MemoryBackend
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Statistics for cache performance."""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
total_saved_ms: float = 0.0
|
||||
estimated_savings_usd: float = 0.0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
total = self.hits + self.misses
|
||||
return self.hits / max(total, 1)
|
||||
|
||||
@property
|
||||
def total_requests(self) -> int:
|
||||
"""Get total number of requests."""
|
||||
return self.hits + self.misses
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"hit_rate": self.hit_rate,
|
||||
"total_requests": self.total_requests,
|
||||
"total_saved_ms": self.total_saved_ms,
|
||||
"estimated_savings_usd": self.estimated_savings_usd,
|
||||
}
|
||||
|
||||
def __iadd__(self, other: "CacheStats") -> "CacheStats":
|
||||
self.hits += other.hits
|
||||
self.misses += other.misses
|
||||
self.total_saved_ms += other.total_saved_ms
|
||||
self.estimated_savings_usd += other.estimated_savings_usd
|
||||
return self
|
||||
|
||||
|
||||
class _StatsManager:
|
||||
"""Manager for global cache statistics.
|
||||
|
||||
Uses threading.Lock for record_hit/record_miss — these are simple counter
|
||||
increments with no awaits inside the lock, so threading.Lock is safe and
|
||||
avoids the overhead of asyncio.Lock for hot-path calls.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize stats manager."""
|
||||
self._stats: dict[str, CacheStats] = {}
|
||||
self._lock = Lock()
|
||||
self._default_backend: Optional[BaseBackend] = None
|
||||
|
||||
def get_backend(self) -> BaseBackend:
|
||||
"""Get default backend for cache operations."""
|
||||
if self._default_backend is None:
|
||||
self._default_backend = MemoryBackend()
|
||||
return self._default_backend
|
||||
|
||||
def set_backend(self, backend: BaseBackend) -> None:
|
||||
"""Set default backend for cache operations."""
|
||||
with self._lock:
|
||||
self._default_backend = backend
|
||||
|
||||
def record_hit(
|
||||
self,
|
||||
namespace: str,
|
||||
latency_saved_ms: float = 0.0,
|
||||
saved_cost: float = 0.0,
|
||||
) -> None:
|
||||
"""Record a cache hit (sync, safe to call from async context)."""
|
||||
with self._lock:
|
||||
if namespace not in self._stats:
|
||||
self._stats[namespace] = CacheStats()
|
||||
stats = self._stats[namespace]
|
||||
stats.hits += 1
|
||||
stats.total_saved_ms += latency_saved_ms
|
||||
stats.estimated_savings_usd += saved_cost
|
||||
|
||||
def record_miss(self, namespace: str) -> None:
|
||||
"""Record a cache miss (sync, safe to call from async context)."""
|
||||
with self._lock:
|
||||
if namespace not in self._stats:
|
||||
self._stats[namespace] = CacheStats()
|
||||
self._stats[namespace].misses += 1
|
||||
|
||||
def get_stats(self, namespace: Optional[str] = None) -> CacheStats:
|
||||
"""Get statistics for namespace or all."""
|
||||
with self._lock:
|
||||
if namespace is not None:
|
||||
return self._stats.get(namespace, CacheStats())
|
||||
|
||||
total = CacheStats()
|
||||
for stats in self._stats.values():
|
||||
total += stats
|
||||
return total
|
||||
|
||||
def clear_stats(self, namespace: Optional[str] = None) -> None:
|
||||
"""Clear statistics for namespace or all."""
|
||||
with self._lock:
|
||||
if namespace is None:
|
||||
self._stats.clear()
|
||||
elif namespace in self._stats:
|
||||
del self._stats[namespace]
|
||||
|
||||
|
||||
# Global stats manager instance
|
||||
_stats_manager = _StatsManager()
|
||||
|
||||
|
||||
def get_stats(namespace: Optional[str] = None) -> dict[str, Any]:
|
||||
"""Get cache statistics (sync).
|
||||
|
||||
Args:
|
||||
namespace: Optional namespace to filter by
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics
|
||||
"""
|
||||
return _stats_manager.get_stats(namespace).to_dict()
|
||||
async def clear_cache(namespace: Optional[str] = None) -> int:
    """Clear all cached entries (async).

    Args:
        namespace: Optional namespace to clear (None = all)

    Returns:
        Number of entries cleared
    """
    backend = _stats_manager.get_backend()

    if namespace is None:
        stats = await backend.get_stats()
        size = stats.get("size", 0)
        await backend.clear()
        _stats_manager.clear_stats()
        return size

    entries = await backend.iterate(namespace=namespace)
    count = len(entries)
    for key, _ in entries:
        await backend.delete(key)
    _stats_manager.clear_stats(namespace)
    return count

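# Editorial usage sketch (not part of the original file): clearing a single
# namespace from async code; clear_cache(None) would wipe every namespace.
async def _example_clear_chat() -> int:
    removed = await clear_cache("chat")
    print(f"cleared {removed} cached entries")
    return removed
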
async def invalidate(
    pattern: str,
    namespace: Optional[str] = None,
) -> int:
    """Invalidate cache entries matching pattern (async).

    Args:
        pattern: Case-insensitive substring to match against cached prompts
        namespace: Optional namespace filter

    Returns:
        Number of entries invalidated
    """
    backend = _stats_manager.get_backend()
    entries = await backend.iterate(namespace=namespace)
    count = 0
    pattern_lower = pattern.lower()

    for key, entry in entries:
        if pattern_lower in entry.prompt.lower():
            await backend.delete(key)
            count += 1

    return count

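# Editorial usage sketch (not part of the original file): substring-based
# invalidation after upstream data changes; "pricing" is an illustrative pattern.
async def _example_invalidate_pricing() -> int:
    return await invalidate("pricing", namespace="chat")
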
async def warm_cache(
    prompts: list[str],
    llm_func: Callable[[str], Any],
    namespace: str = "default",
) -> int:
    """Pre-populate cache with prompts (async).

    Args:
        prompts: List of prompts to cache
        llm_func: Async or sync LLM function to call for each prompt
        namespace: Cache namespace to use

    Returns:
        Number of prompts attempted
    """
    import inspect

    from semantic_llm_cache.core import cache

    cached_func = cache(namespace=namespace)(llm_func)

    for prompt in prompts:
        try:
            result = cached_func(prompt)
            if inspect.isawaitable(result):
                await result
        except Exception:
            # Warm-up is best-effort: one failing prompt should not abort the rest.
            pass

    return len(prompts)

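# Editorial usage sketch (not part of the original file): warming the cache
# with an async stand-in for a real provider call.
async def _example_warm_cache() -> None:
    async def fake_llm(prompt: str) -> str:
        return f"stub answer for: {prompt}"

    attempted = await warm_cache(["What is semantic caching?"], fake_llm, namespace="docs")
    print(f"attempted {attempted} warm-up prompts")
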
async def export_cache(
    namespace: Optional[str] = None,
    filepath: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Export cache entries for analysis (async).

    Args:
        namespace: Optional namespace filter
        filepath: Optional file path to save export (JSON)

    Returns:
        List of cache entry dictionaries
    """
    import json
    from datetime import datetime

    backend = _stats_manager.get_backend()
    entries = await backend.iterate(namespace=namespace)

    export_data = []
    for key, entry in entries:
        export_data.append({
            "key": key,
            "prompt": entry.prompt,
            "response": str(entry.response)[:1000],
            "namespace": entry.namespace,
            "hit_count": entry.hit_count,
            "created_at": datetime.fromtimestamp(entry.created_at).isoformat(),
            "ttl": entry.ttl,
            "input_tokens": entry.input_tokens,
            "output_tokens": entry.output_tokens,
        })

    if filepath:
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(export_data, f, indent=2)

    return export_data

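# Editorial usage sketch (not part of the original file): exporting one
# namespace for offline analysis; the path is illustrative.
async def _example_export_chat() -> None:
    rows = await export_cache(namespace="chat", filepath="/tmp/cache_export.json")
    print(f"exported {len(rows)} entries")
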
111
semantic_llm_cache/storage.py
Normal file
@@ -0,0 +1,111 @@
"""Storage backend interface for prompt-cache."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
|
||||
|
||||
class StorageBackend(ABC):
|
||||
"""Abstract base class for async cache storage backends."""
|
||||
|
||||
    @abstractmethod
    async def get(self, key: str) -> Optional[CacheEntry]:
        """Retrieve cache entry by key.

        Args:
            key: Cache key to retrieve

        Returns:
            CacheEntry if found and not expired, None otherwise

        Raises:
            CacheBackendError: If backend operation fails
        """
        pass

    @abstractmethod
    async def set(self, key: str, entry: CacheEntry) -> None:
        """Store cache entry.

        Args:
            key: Cache key to store under
            entry: CacheEntry to store

        Raises:
            CacheBackendError: If backend operation fails
        """
        pass

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """Delete cache entry.

        Args:
            key: Cache key to delete

        Returns:
            True if entry was deleted, False if not found

        Raises:
            CacheBackendError: If backend operation fails
        """
        pass

    @abstractmethod
    async def clear(self) -> None:
        """Clear all cache entries.

        Raises:
            CacheBackendError: If backend operation fails
        """
        pass

    @abstractmethod
    async def iterate(self, namespace: Optional[str] = None) -> list[tuple[str, CacheEntry]]:
        """Iterate over cache entries, optionally filtered by namespace.

        Args:
            namespace: Optional namespace filter

        Returns:
            List of (key, entry) tuples

        Raises:
            CacheBackendError: If backend operation fails
        """
        pass

    @abstractmethod
    async def find_similar(
        self,
        embedding: list[float],
        threshold: float,
        namespace: Optional[str] = None,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        """Find semantically similar cached entry.

        Args:
            embedding: Query embedding vector
            threshold: Minimum similarity score (0-1)
            namespace: Optional namespace filter

        Returns:
            (key, entry, similarity) tuple if found above threshold, None otherwise

        Raises:
            CacheBackendError: If backend operation fails
        """
        pass

    @abstractmethod
    async def get_stats(self) -> dict[str, Any]:
        """Get backend statistics.

        Returns:
            Dictionary with stats like size, memory_usage, etc.

        Raises:
            CacheBackendError: If backend operation fails
        """
        pass

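# Editorial sketch (not part of the original file): a minimal dict-backed
# implementation of the StorageBackend contract, to make the ABC concrete.
# It assumes each CacheEntry carries the embedding it was stored with (an
# assumption; the real backends may store embeddings differently). No TTL
# handling or error wrapping is shown.
class DictBackend(StorageBackend):
    """Toy in-memory backend (illustrative only)."""

    def __init__(self) -> None:
        self._data: dict[str, CacheEntry] = {}

    async def get(self, key: str) -> Optional[CacheEntry]:
        return self._data.get(key)

    async def set(self, key: str, entry: CacheEntry) -> None:
        self._data[key] = entry

    async def delete(self, key: str) -> bool:
        return self._data.pop(key, None) is not None

    async def clear(self) -> None:
        self._data.clear()

    async def iterate(self, namespace: Optional[str] = None) -> list[tuple[str, CacheEntry]]:
        return [
            (key, entry)
            for key, entry in self._data.items()
            if namespace is None or entry.namespace == namespace
        ]

    async def find_similar(
        self,
        embedding: list[float],
        threshold: float,
        namespace: Optional[str] = None,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        # Imported here, not at module level, because backends.base imports
        # this module; a top-level import would be circular.
        from semantic_llm_cache.backends.base import cosine_similarity

        best: Optional[tuple[str, CacheEntry, float]] = None
        for key, entry in await self.iterate(namespace):
            stored = getattr(entry, "embedding", None)  # assumed field
            if stored is None:
                continue
            score = cosine_similarity(embedding, stored)
            if score >= threshold and (best is None or score > best[2]):
                best = (key, entry, score)
        return best

    async def get_stats(self) -> dict[str, Any]:
        return {"size": len(self._data)}
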
97
semantic_llm_cache/utils/__init__.py
Normal file
@@ -0,0 +1,97 @@
"""Utility functions for prompt-cache."""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
def normalize_prompt(prompt: str) -> str:
    """Normalize prompt text for consistent caching.

    Args:
        prompt: Raw prompt text

    Returns:
        Normalized prompt text
    """
    # Collapse runs of whitespace
    prompt = " ".join(prompt.split())

    # Lowercasing would improve match rates but can change semantics,
    # so it is intentionally left disabled:
    # prompt = prompt.lower()

    # Remove a common filler phrase at the start
    filler_pattern = r"^(please|can you|could you|i need|i want)\s+"
    prompt = re.sub(filler_pattern, "", prompt, flags=re.IGNORECASE)

    # Normalize quotes
    prompt = prompt.replace('"', "'").replace("`", "'")

    # Remove trailing punctuation
    prompt = prompt.rstrip("?!.")

    return prompt.strip()

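# Editorial sketch (not part of the original file): how normalization folds
# near-duplicate phrasings onto one canonical form.
def _example_normalize() -> None:
    a = normalize_prompt("Please  summarize   this doc?")
    b = normalize_prompt("summarize this doc")
    assert a == b == "summarize this doc"
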
def hash_prompt(prompt: str, namespace: str = "default") -> str:
    """Generate cache key from prompt and namespace.

    Args:
        prompt: Prompt text
        namespace: Cache namespace

    Returns:
        Hash-based cache key (hex SHA-256 of "namespace:prompt")
    """
    combined = f"{namespace}:{prompt}"
    return hashlib.sha256(combined.encode()).hexdigest()

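# Editorial sketch (not part of the original file): normalizing before hashing
# so near-identical phrasings map to the same exact-match key.
def _example_cache_key() -> str:
    key = hash_prompt(normalize_prompt("Please explain caching?"), namespace="docs")
    return key  # 64-hex-char SHA-256 digest
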
def estimate_tokens(text: str) -> int:
    """Estimate token count for text (rough approximation).

    Args:
        text: Input text

    Returns:
        Estimated token count
    """
    # Rough approximation: ~4 chars per token
    return len(text) // 4

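# Editorial sketch (not part of the original file): using the 4-chars-per-token
# heuristic for a rough cost estimate; the per-1K-token price is a placeholder.
def _example_estimated_cost(text: str, usd_per_1k_tokens: float = 0.01) -> float:
    return estimate_tokens(text) / 1000 * usd_per_1k_tokens
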
def serialize_response(response: Any) -> str:
    """Serialize response for storage.

    Args:
        response: Response object (string, dict, etc.)

    Returns:
        Serialized JSON string
    """
    return json.dumps(response)


def deserialize_response(data: str) -> Any:
    """Deserialize response from storage.

    Args:
        data: Serialized JSON string

    Returns:
        Deserialized response object
    """
    return json.loads(data)

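# Editorial sketch (not part of the original file): a serialize/deserialize
# round-trip, as a backend would persist and restore a response.
def _example_roundtrip() -> None:
    original = {"role": "assistant", "content": "hello"}
    assert deserialize_response(serialize_response(original)) == original
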
__all__ = [
    "normalize_prompt",
    "hash_prompt",
    "estimate_tokens",
    "serialize_response",
    "deserialize_response",
]