Add files via upload
initial commit
This commit is contained in:
parent
8d3d5ff628
commit
b33bb415dd
24 changed files with 4840 additions and 0 deletions
21
semantic_llm_cache/backends/__init__.py
Normal file
21
semantic_llm_cache/backends/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""Storage backends for llm-semantic-cache."""
|
||||
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.backends.memory import MemoryBackend
|
||||
|
||||
try:
|
||||
from semantic_llm_cache.backends.sqlite import SQLiteBackend
|
||||
except ImportError:
|
||||
SQLiteBackend = None # type: ignore
|
||||
|
||||
try:
|
||||
from semantic_llm_cache.backends.redis import RedisBackend
|
||||
except ImportError:
|
||||
RedisBackend = None # type: ignore
|
||||
|
||||
__all__ = [
|
||||
"BaseBackend",
|
||||
"MemoryBackend",
|
||||
"SQLiteBackend",
|
||||
"RedisBackend",
|
||||
]
|
||||
104
semantic_llm_cache/backends/base.py
Normal file
104
semantic_llm_cache/backends/base.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Base backend implementation with common functionality."""
|
||||
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.storage import StorageBackend
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two vectors.
|
||||
|
||||
Args:
|
||||
a: First vector
|
||||
b: Second vector
|
||||
|
||||
Returns:
|
||||
Similarity score between 0 and 1
|
||||
"""
|
||||
a_arr = np.asarray(a, dtype=np.float32)
|
||||
b_arr = np.asarray(b, dtype=np.float32)
|
||||
|
||||
dot_product = np.dot(a_arr, b_arr)
|
||||
norm_a = np.linalg.norm(a_arr)
|
||||
norm_b = np.linalg.norm(b_arr)
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
|
||||
return float(dot_product / (norm_a * norm_b))
|
||||
|
||||
|
||||
class BaseBackend(StorageBackend):
    """Base backend with common sync helpers; async public interface via StorageBackend."""

    def __init__(self) -> None:
        """Initialize the hit/miss counters shared by all backends."""
        self._hits: int = 0
        self._misses: int = 0

    def _increment_hits(self) -> None:
        """Record one cache hit."""
        self._hits = self._hits + 1

    def _increment_misses(self) -> None:
        """Record one cache miss."""
        self._misses = self._misses + 1

    def _check_expired(self, entry: CacheEntry) -> bool:
        """Return True when *entry* reports itself expired at the current time.

        Args:
            entry: CacheEntry to check

        Returns:
            True if expired, False otherwise
        """
        now = time.time()
        return entry.is_expired(now)

    def _find_best_match(
        self,
        candidates: list[tuple[str, CacheEntry]],
        query_embedding: list[float],
        threshold: float,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        """Scan *candidates* and return the one most similar to the query.

        Sync helper — CPU-only numpy ops, safe to call from async context.

        Args:
            candidates: (key, entry) pairs; entries without an embedding
                are skipped.
            query_embedding: Query embedding vector.
            threshold: Minimum similarity (exclusive) for a match.

        Returns:
            (key, entry, similarity) for the best entry strictly above the
            threshold, or None when nothing qualifies.
        """
        winner: Optional[tuple[str, CacheEntry, float]] = None
        best_so_far = threshold

        for cand_key, cand_entry in candidates:
            if cand_entry.embedding is None:
                continue
            score = cosine_similarity(query_embedding, cand_entry.embedding)
            # Strict > means later candidates must beat, not tie, the leader.
            if score > best_so_far:
                best_so_far = score
                winner = (cand_key, cand_entry, score)

        return winner

    async def get_stats(self) -> dict[str, Any]:
        """Get backend statistics.

        Returns:
            Dictionary with hits, misses, and the derived hit rate.
        """
        total = self._hits + self._misses
        return {
            "hits": self._hits,
            "misses": self._misses,
            # max(total, 1) avoids ZeroDivisionError before any lookups.
            "hit_rate": self._hits / max(total, 1),
        }
|
||||
179
semantic_llm_cache/backends/memory.py
Normal file
179
semantic_llm_cache/backends/memory.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
"""In-memory storage backend."""
|
||||
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
|
||||
class MemoryBackend(BaseBackend):
    """In-memory cache storage with LRU eviction.

    All operations are in-memory dict access — no I/O — so async methods
    run directly in the event loop without thread offloading.
    """

    def __init__(self, max_size: Optional[int] = None) -> None:
        """Initialize memory backend.

        Args:
            max_size: Maximum number of entries to store (LRU eviction when
                reached); None means unbounded.
        """
        super().__init__()
        self._cache: dict[str, CacheEntry] = {}
        # key -> monotonically increasing stamp; smallest stamp = least recent.
        self._access_order: dict[str, float] = {}
        self._max_size = max_size
        self._access_counter: float = 0.0

    def _evict_if_needed(self) -> None:
        """Evict the least-recently-used entry if the cache is at capacity."""
        if self._max_size is None or len(self._cache) < self._max_size:
            return

        if self._access_order:
            lru_key = min(self._access_order, key=lambda k: self._access_order.get(k, 0))
            del self._cache[lru_key]
            del self._access_order[lru_key]

    def _update_access_time(self, key: str) -> None:
        """Stamp *key* as most recently used for LRU tracking."""
        self._access_counter += 1
        self._access_order[key] = self._access_counter

    async def get(self, key: str) -> Optional[CacheEntry]:
        """Retrieve cache entry by key.

        Args:
            key: Cache key to retrieve

        Returns:
            CacheEntry if found and not expired, None otherwise
        """
        try:
            entry = self._cache.get(key)
            if entry is None:
                self._increment_misses()
                return None

            # Expired entries are removed lazily on access and count as misses.
            if self._check_expired(entry):
                await self.delete(key)
                self._increment_misses()
                return None

            self._increment_hits()
            self._update_access_time(key)
            entry.hit_count += 1
            return entry
        except Exception as e:
            raise CacheBackendError(f"Failed to get entry: {e}") from e

    async def set(self, key: str, entry: CacheEntry) -> None:
        """Store cache entry.

        Args:
            key: Cache key to store under
            entry: CacheEntry to store
        """
        try:
            # Bug fix: only evict for genuinely new keys.  Overwriting an
            # existing key does not grow the cache, so evicting here would
            # needlessly drop an unrelated LRU entry (and could even evict
            # the key being updated).
            if key not in self._cache:
                self._evict_if_needed()
            self._cache[key] = entry
            self._update_access_time(key)
        except Exception as e:
            raise CacheBackendError(f"Failed to set entry: {e}") from e

    async def delete(self, key: str) -> bool:
        """Delete cache entry.

        Args:
            key: Cache key to delete

        Returns:
            True if entry was deleted, False if not found
        """
        try:
            if key in self._cache:
                del self._cache[key]
                self._access_order.pop(key, None)
                return True
            return False
        except Exception as e:
            raise CacheBackendError(f"Failed to delete entry: {e}") from e

    async def clear(self) -> None:
        """Clear all cache entries."""
        try:
            self._cache.clear()
            self._access_order.clear()
        except Exception as e:
            raise CacheBackendError(f"Failed to clear cache: {e}") from e

    async def iterate(
        self, namespace: Optional[str] = None
    ) -> list[tuple[str, CacheEntry]]:
        """Iterate over cache entries, optionally filtered by namespace.

        Args:
            namespace: Optional namespace filter

        Returns:
            List of (key, entry) tuples.  NOTE(review): with namespace=None
            this returns every entry including expired ones (matches the
            original behavior); with a namespace, expired entries are
            filtered out.
        """
        try:
            if namespace is None:
                return list(self._cache.items())

            return [
                (k, v)
                for k, v in self._cache.items()
                if v.namespace == namespace and not self._check_expired(v)
            ]
        except Exception as e:
            raise CacheBackendError(f"Failed to iterate entries: {e}") from e

    async def find_similar(
        self,
        embedding: list[float],
        threshold: float,
        namespace: Optional[str] = None,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        """Find semantically similar cached entry.

        Args:
            embedding: Query embedding vector
            threshold: Minimum similarity score (0-1)
            namespace: Optional namespace filter

        Returns:
            (key, entry, similarity) tuple if found above threshold, None otherwise
        """
        try:
            candidates = [
                (k, v)
                for k, v in self._cache.items()
                if v.embedding is not None
                and not self._check_expired(v)
                and (namespace is None or v.namespace == namespace)
            ]
            return self._find_best_match(candidates, embedding, threshold)
        except Exception as e:
            raise CacheBackendError(f"Failed to find similar entry: {e}") from e

    async def get_stats(self) -> dict[str, Any]:
        """Get backend statistics.

        Returns:
            Dictionary with size, memory usage, hits, misses
        """
        base_stats = await super().get_stats()
        # sys.getsizeof is shallow — this is a rough lower bound, not the
        # true footprint of the entries' contents.
        memory_usage = sys.getsizeof(self._cache) + sum(
            sys.getsizeof(k) + sys.getsizeof(v) for k, v in self._cache.items()
        )

        return {
            **base_stats,
            "size": len(self._cache),
            "memory_bytes": memory_usage,
            "max_size": self._max_size,
        }
|
||||
239
semantic_llm_cache/backends/redis.py
Normal file
239
semantic_llm_cache/backends/redis.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""Redis distributed storage backend (async via redis.asyncio)."""
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
from redis import asyncio as aioredis
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Redis backend requires 'redis' package. "
|
||||
"Install with: pip install semantic-llm-cache[redis]"
|
||||
) from err
|
||||
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
|
||||
class RedisBackend(BaseBackend):
    """Redis-based distributed cache storage (async).

    Uses redis.asyncio (bundled with redis>=4.2) for non-blocking I/O.
    The connection is created in __init__; no explicit connect() call needed
    as redis.asyncio uses a connection pool that connects lazily.
    """

    DEFAULT_PREFIX = "semantic_llm_cache:"

    def __init__(
        self,
        url: str = "redis://localhost:6379/0",
        prefix: str = DEFAULT_PREFIX,
        **kwargs: Any,
    ) -> None:
        """Initialize Redis backend.

        Args:
            url: Redis connection URL
            prefix: Key prefix for cache entries
            **kwargs: Additional arguments passed to redis.asyncio.from_url
        """
        super().__init__()
        self._prefix = prefix.rstrip(":") + ":"
        self._redis = aioredis.from_url(url, **kwargs)

    async def ping(self) -> None:
        """Test Redis connection. Call this after construction to verify connectivity.

        Raises:
            CacheBackendError: If Redis is not reachable
        """
        try:
            await self._redis.ping()
        except Exception as e:
            raise CacheBackendError(f"Failed to connect to Redis: {e}") from e

    def _make_key(self, key: str) -> str:
        """Create full Redis key with prefix."""
        return f"{self._prefix}{key}"

    def _entry_to_dict(self, entry: CacheEntry) -> dict[str, Any]:
        """Convert CacheEntry to dictionary for storage."""
        return {
            "prompt": entry.prompt,
            "response": entry.response,
            "embedding": entry.embedding,
            "created_at": entry.created_at,
            "ttl": entry.ttl,
            "namespace": entry.namespace,
            "hit_count": entry.hit_count,
            "input_tokens": entry.input_tokens,
            "output_tokens": entry.output_tokens,
        }

    def _dict_to_entry(self, data: dict[str, Any]) -> CacheEntry:
        """Convert dictionary from storage to CacheEntry."""
        return CacheEntry(
            prompt=data["prompt"],
            response=data["response"],
            embedding=data.get("embedding"),
            created_at=data["created_at"],
            ttl=data.get("ttl"),
            namespace=data.get("namespace", "default"),
            hit_count=data.get("hit_count", 0),
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
        )

    async def get(self, key: str) -> Optional[CacheEntry]:
        """Retrieve cache entry by key.

        Args:
            key: Cache key to retrieve

        Returns:
            CacheEntry if found and not expired, None otherwise
        """
        try:
            redis_key = self._make_key(key)
            data = await self._redis.get(redis_key)

            if data is None:
                self._increment_misses()
                return None

            entry_dict = json.loads(data)
            entry = self._dict_to_entry(entry_dict)

            if self._check_expired(entry):
                await self.delete(key)
                self._increment_misses()
                return None

            self._increment_hits()
            entry.hit_count += 1

            # Bug fix: a plain SET clears the key's Redis expiry, so every
            # hit on a TTL'd entry used to make it immortal.  Preserve the
            # remaining TTL when writing the updated hit_count back.
            # PTTL returns -1 (no expiry) or -2 (key gone) — only re-apply
            # a positive remainder.  Small race between PTTL and SET is
            # acceptable for a hit counter.
            entry_dict["hit_count"] = entry.hit_count
            payload = json.dumps(entry_dict)
            ttl_ms = await self._redis.pttl(redis_key)
            if ttl_ms is not None and ttl_ms > 0:
                await self._redis.set(redis_key, payload, px=ttl_ms)
            else:
                await self._redis.set(redis_key, payload)

            return entry
        except Exception as e:
            raise CacheBackendError(f"Failed to get entry: {e}") from e

    async def set(self, key: str, entry: CacheEntry) -> None:
        """Store cache entry.

        Args:
            key: Cache key to store under
            entry: CacheEntry to store
        """
        try:
            redis_key = self._make_key(key)
            data = json.dumps(self._entry_to_dict(entry))
            # ttl <= 0 or None means "no expiry": ex=None stores without TTL.
            redis_ttl = entry.ttl if entry.ttl is not None else 0
            await self._redis.set(redis_key, data, ex=redis_ttl if redis_ttl > 0 else None)
        except Exception as e:
            raise CacheBackendError(f"Failed to set entry: {e}") from e

    async def delete(self, key: str) -> bool:
        """Delete cache entry.

        Args:
            key: Cache key to delete

        Returns:
            True if entry was deleted, False if not found
        """
        try:
            result = await self._redis.delete(self._make_key(key))
            return result > 0
        except Exception as e:
            raise CacheBackendError(f"Failed to delete entry: {e}") from e

    async def clear(self) -> None:
        """Clear all cache entries with this prefix."""
        try:
            # NOTE(review): KEYS blocks Redis on large keyspaces; consider
            # SCAN for production-sized deployments.
            keys = await self._redis.keys(f"{self._prefix}*")
            if keys:
                await self._redis.delete(*keys)
        except Exception as e:
            raise CacheBackendError(f"Failed to clear cache: {e}") from e

    async def iterate(
        self, namespace: Optional[str] = None
    ) -> list[tuple[str, CacheEntry]]:
        """Iterate over cache entries, optionally filtered by namespace.

        Args:
            namespace: Optional namespace filter

        Returns:
            List of (key, entry) tuples
        """
        try:
            keys = await self._redis.keys(f"{self._prefix}*")
            results = []

            for full_key in keys:
                # Robustness fix: keys() yields bytes by default but str
                # when the client was created with decode_responses=True.
                key_text = full_key.decode() if isinstance(full_key, bytes) else full_key
                short_key = key_text.replace(self._prefix, "", 1)
                data = await self._redis.get(full_key)

                if data:
                    entry_dict = json.loads(data)
                    entry = self._dict_to_entry(entry_dict)

                    if namespace is None or entry.namespace == namespace:
                        if not self._check_expired(entry):
                            results.append((short_key, entry))

            return results
        except Exception as e:
            raise CacheBackendError(f"Failed to iterate entries: {e}") from e

    async def find_similar(
        self,
        embedding: list[float],
        threshold: float,
        namespace: Optional[str] = None,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        """Find semantically similar cached entry.

        Note: Loads all entries for cosine scan. For large datasets consider
        Redis Stack with vector search (RediSearch).

        Args:
            embedding: Query embedding vector
            threshold: Minimum similarity score (0-1)
            namespace: Optional namespace filter

        Returns:
            (key, entry, similarity) tuple if found above threshold, None otherwise
        """
        try:
            entries = await self.iterate(namespace)
            candidates = [(k, v) for k, v in entries if v.embedding is not None]
            return self._find_best_match(candidates, embedding, threshold)
        except Exception as e:
            raise CacheBackendError(f"Failed to find similar entry: {e}") from e

    async def get_stats(self) -> dict[str, Any]:
        """Get backend statistics."""
        base_stats = await super().get_stats()

        try:
            keys = await self._redis.keys(f"{self._prefix}*")
            return {
                **base_stats,
                "size": len(keys) if keys else 0,
                "prefix": self._prefix,
            }
        except Exception as e:
            # Stats are best-effort: report the failure instead of raising.
            return {**base_stats, "size": 0, "prefix": self._prefix, "error": str(e)}

    async def close(self) -> None:
        """Close Redis connection."""
        try:
            await self._redis.aclose()
        except Exception:
            # Best-effort shutdown; nothing actionable if close fails.
            pass
|
||||
279
semantic_llm_cache/backends/sqlite.py
Normal file
279
semantic_llm_cache/backends/sqlite.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"""SQLite persistent storage backend (async via aiosqlite)."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
import aiosqlite
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"SQLite backend requires 'aiosqlite' package. "
|
||||
"Install with: pip install semantic-llm-cache[sqlite]"
|
||||
) from err
|
||||
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.config import CacheEntry
|
||||
from semantic_llm_cache.exceptions import CacheBackendError
|
||||
|
||||
|
||||
class SQLiteBackend(BaseBackend):
    """SQLite-based persistent cache storage (async).

    Uses aiosqlite for non-blocking I/O. A single persistent connection
    is opened lazily on first use and reused for all subsequent operations.
    """

    def __init__(self, db_path: str | Path = "semantic_cache.db") -> None:
        """Initialize SQLite backend.

        Args:
            db_path: Path to SQLite database file, or ":memory:" for in-memory DB
        """
        super().__init__()
        self._db_path = str(db_path) if isinstance(db_path, Path) else db_path
        self._conn: Optional[aiosqlite.Connection] = None

    async def _get_conn(self) -> aiosqlite.Connection:
        """Get or create the persistent async connection."""
        if self._conn is None:
            # Assign before initializing the schema: _initialize_schema
            # calls _get_conn again, and the assignment prevents recursion.
            self._conn = await aiosqlite.connect(self._db_path)
            self._conn.row_factory = aiosqlite.Row
            await self._initialize_schema()
        return self._conn

    async def _initialize_schema(self) -> None:
        """Initialize database schema."""
        conn = await self._get_conn()
        await conn.execute(
            """
            CREATE TABLE IF NOT EXISTS cache_entries (
                key TEXT PRIMARY KEY,
                prompt TEXT NOT NULL,
                response TEXT NOT NULL,
                embedding TEXT,
                created_at REAL NOT NULL,
                ttl INTEGER,
                namespace TEXT NOT NULL DEFAULT 'default',
                hit_count INTEGER DEFAULT 0,
                input_tokens INTEGER DEFAULT 0,
                output_tokens INTEGER DEFAULT 0
            )
            """
        )
        await conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_namespace
            ON cache_entries(namespace)
            """
        )
        await conn.commit()

    def _row_to_entry(self, row: aiosqlite.Row) -> CacheEntry:
        """Convert database row to CacheEntry."""
        embedding = None
        # Bug fix: `is not None` instead of truthiness, so a stored
        # empty-list embedding ("[]") round-trips instead of becoming None.
        if row["embedding"] is not None:
            embedding = json.loads(row["embedding"])

        return CacheEntry(
            prompt=row["prompt"],
            response=json.loads(row["response"]),
            embedding=embedding,
            created_at=row["created_at"],
            ttl=row["ttl"],
            namespace=row["namespace"],
            hit_count=row["hit_count"],
            input_tokens=row["input_tokens"],
            output_tokens=row["output_tokens"],
        )

    async def get(self, key: str) -> Optional[CacheEntry]:
        """Retrieve cache entry by key.

        Args:
            key: Cache key to retrieve

        Returns:
            CacheEntry if found and not expired, None otherwise
        """
        try:
            conn = await self._get_conn()
            async with conn.execute(
                "SELECT * FROM cache_entries WHERE key = ?", (key,)
            ) as cursor:
                row = await cursor.fetchone()

            if row is None:
                self._increment_misses()
                return None

            entry = self._row_to_entry(row)

            # Expired rows are removed lazily on access and count as misses.
            if self._check_expired(entry):
                await self.delete(key)
                self._increment_misses()
                return None

            self._increment_hits()
            entry.hit_count += 1

            await conn.execute(
                "UPDATE cache_entries SET hit_count = hit_count + 1 WHERE key = ?",
                (key,),
            )
            await conn.commit()

            return entry
        except Exception as e:
            raise CacheBackendError(f"Failed to get entry: {e}") from e

    async def set(self, key: str, entry: CacheEntry) -> None:
        """Store cache entry.

        Args:
            key: Cache key to store under
            entry: CacheEntry to store
        """
        try:
            conn = await self._get_conn()
            # Bug fix: `is not None` instead of truthiness — an empty-list
            # embedding used to be silently stored as NULL.
            embedding_json = (
                json.dumps(entry.embedding) if entry.embedding is not None else None
            )
            response_json = json.dumps(entry.response)

            await conn.execute(
                """
                INSERT OR REPLACE INTO cache_entries
                (key, prompt, response, embedding, created_at, ttl, namespace,
                 hit_count, input_tokens, output_tokens)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    key,
                    entry.prompt,
                    response_json,
                    embedding_json,
                    entry.created_at,
                    entry.ttl,
                    entry.namespace,
                    entry.hit_count,
                    entry.input_tokens,
                    entry.output_tokens,
                ),
            )
            await conn.commit()
        except Exception as e:
            raise CacheBackendError(f"Failed to set entry: {e}") from e

    async def delete(self, key: str) -> bool:
        """Delete cache entry.

        Args:
            key: Cache key to delete

        Returns:
            True if entry was deleted, False if not found
        """
        try:
            conn = await self._get_conn()
            async with conn.execute(
                "DELETE FROM cache_entries WHERE key = ?", (key,)
            ) as cursor:
                rowcount = cursor.rowcount
            await conn.commit()
            return rowcount > 0
        except Exception as e:
            raise CacheBackendError(f"Failed to delete entry: {e}") from e

    async def clear(self) -> None:
        """Clear all cache entries."""
        try:
            conn = await self._get_conn()
            await conn.execute("DELETE FROM cache_entries")
            await conn.commit()
        except Exception as e:
            raise CacheBackendError(f"Failed to clear cache: {e}") from e

    async def iterate(
        self, namespace: Optional[str] = None
    ) -> list[tuple[str, CacheEntry]]:
        """Iterate over cache entries, optionally filtered by namespace.

        Args:
            namespace: Optional namespace filter

        Returns:
            List of (key, entry) tuples (expired entries excluded)
        """
        try:
            conn = await self._get_conn()

            # `SELECT *` already includes the key column; the previous
            # `SELECT key, *` produced a redundant duplicate column.
            if namespace is None:
                query = "SELECT * FROM cache_entries"
                params: tuple[Any, ...] = ()
            else:
                query = "SELECT * FROM cache_entries WHERE namespace = ?"
                params = (namespace,)

            async with conn.execute(query, params) as cursor:
                rows = await cursor.fetchall()

            results = []
            for row in rows:
                entry = self._row_to_entry(row)
                if not self._check_expired(entry):
                    results.append((row["key"], entry))

            return results
        except Exception as e:
            raise CacheBackendError(f"Failed to iterate entries: {e}") from e

    async def find_similar(
        self,
        embedding: list[float],
        threshold: float,
        namespace: Optional[str] = None,
    ) -> Optional[tuple[str, CacheEntry, float]]:
        """Find semantically similar cached entry.

        Args:
            embedding: Query embedding vector
            threshold: Minimum similarity score (0-1)
            namespace: Optional namespace filter

        Returns:
            (key, entry, similarity) tuple if found above threshold, None otherwise
        """
        try:
            entries = await self.iterate(namespace)
            candidates = [(k, v) for k, v in entries if v.embedding is not None]
            return self._find_best_match(candidates, embedding, threshold)
        except Exception as e:
            raise CacheBackendError(f"Failed to find similar entry: {e}") from e

    async def get_stats(self) -> dict[str, Any]:
        """Get backend statistics.

        Returns:
            Dictionary with size, database path, hits, misses
        """
        base_stats = await super().get_stats()

        try:
            conn = await self._get_conn()
            async with conn.execute("SELECT COUNT(*) FROM cache_entries") as cursor:
                row = await cursor.fetchone()
                size = row[0] if row else 0

            return {
                **base_stats,
                "size": size,
                "db_path": self._db_path,
            }
        except Exception as e:
            # Stats are best-effort: report the failure instead of raising.
            return {**base_stats, "size": 0, "db_path": self._db_path, "error": str(e)}

    async def close(self) -> None:
        """Close database connection."""
        if self._conn is not None:
            await self._conn.close()
            self._conn = None
|
||||
Loading…
Add table
Add a link
Reference in a new issue