Add files via upload
initial commit
This commit is contained in:
parent
8d3d5ff628
commit
b33bb415dd
24 changed files with 4840 additions and 0 deletions
369
semantic_llm_cache/core.py
Normal file
369
semantic_llm_cache/core.py
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
"""Core cache decorator and API for llm-semantic-cache."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import time
|
||||
from typing import Any, Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
from semantic_llm_cache.backends import MemoryBackend
|
||||
from semantic_llm_cache.backends.base import BaseBackend
|
||||
from semantic_llm_cache.config import CacheConfig, CacheEntry
|
||||
from semantic_llm_cache.exceptions import PromptCacheError
|
||||
from semantic_llm_cache.similarity import EmbeddingCache
|
||||
from semantic_llm_cache.stats import _stats_manager
|
||||
from semantic_llm_cache.utils import hash_prompt, normalize_prompt
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _extract_prompt(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
"""Extract prompt string from function arguments."""
|
||||
if args and isinstance(args[0], str):
|
||||
return args[0]
|
||||
if "prompt" in kwargs:
|
||||
return str(kwargs["prompt"])
|
||||
return str(args) + str(sorted(kwargs.items()))
|
||||
|
||||
|
||||
class CacheContext:
|
||||
"""Context manager for cache configuration.
|
||||
|
||||
Supports both sync (with) and async (async with) usage.
|
||||
|
||||
Examples:
|
||||
>>> async with CacheContext(similarity=0.9) as ctx:
|
||||
... result = await llm_call("prompt")
|
||||
... print(ctx.stats)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
similarity: Optional[float] = None,
|
||||
ttl: Optional[int] = None,
|
||||
namespace: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
) -> None:
|
||||
self._config = CacheConfig(
|
||||
similarity_threshold=similarity if similarity is not None else 1.0,
|
||||
ttl=ttl,
|
||||
namespace=namespace if namespace is not None else "default",
|
||||
enabled=enabled if enabled is not None else True,
|
||||
)
|
||||
self._stats: dict[str, Any] = {"hits": 0, "misses": 0}
|
||||
|
||||
def __enter__(self) -> "CacheContext":
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
pass
|
||||
|
||||
async def __aenter__(self) -> "CacheContext":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def stats(self) -> dict[str, Any]:
|
||||
return self._stats.copy()
|
||||
|
||||
@property
|
||||
def config(self) -> CacheConfig:
|
||||
return self._config
|
||||
|
||||
|
||||
class CachedLLM:
|
||||
"""Wrapper class for LLM calls with automatic caching.
|
||||
|
||||
Examples:
|
||||
>>> llm = CachedLLM(similarity=0.9)
|
||||
>>> response = await llm.achat("What is Python?", llm_func=my_async_llm)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str = "openai",
|
||||
model: str = "gpt-4",
|
||||
similarity: float = 1.0,
|
||||
ttl: Optional[int] = 3600,
|
||||
backend: Optional[BaseBackend] = None,
|
||||
namespace: str = "default",
|
||||
enabled: bool = True,
|
||||
) -> None:
|
||||
self._provider = provider
|
||||
self._model = model
|
||||
self._backend = backend or MemoryBackend()
|
||||
self._embedding_cache = EmbeddingCache()
|
||||
self._config = CacheConfig(
|
||||
similarity_threshold=similarity,
|
||||
ttl=ttl,
|
||||
namespace=namespace,
|
||||
enabled=enabled,
|
||||
)
|
||||
|
||||
async def achat(
|
||||
self,
|
||||
prompt: str,
|
||||
llm_func: Optional[Callable[[str], Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Get response with caching (async).
|
||||
|
||||
Args:
|
||||
prompt: Input prompt
|
||||
llm_func: Async or sync LLM function to call on cache miss
|
||||
**kwargs: Additional arguments for llm_func
|
||||
|
||||
Returns:
|
||||
LLM response (cached or fresh)
|
||||
"""
|
||||
if llm_func is None:
|
||||
raise ValueError("llm_func is required for CachedLLM.achat()")
|
||||
|
||||
@cache(
|
||||
similarity=self._config.similarity_threshold,
|
||||
ttl=self._config.ttl,
|
||||
backend=self._backend,
|
||||
namespace=self._config.namespace,
|
||||
enabled=self._config.enabled,
|
||||
)
|
||||
async def _cached_call(p: str) -> Any:
|
||||
result = llm_func(p, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
return await _cached_call(prompt)
|
||||
|
||||
|
||||
def cache(
|
||||
similarity: float = 1.0,
|
||||
ttl: Optional[int] = 3600,
|
||||
backend: Optional[BaseBackend] = None,
|
||||
namespace: str = "default",
|
||||
enabled: bool = True,
|
||||
key_func: Optional[Callable[..., str]] = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Decorator for caching LLM function responses.
|
||||
|
||||
Auto-detects whether the decorated function is async or sync and returns
|
||||
the appropriate wrapper. Both variants share identical cache logic.
|
||||
|
||||
Async functions get a true async wrapper (awaits all backend calls).
|
||||
Sync functions get a sync wrapper that drives the async backends via a
|
||||
temporary event loop — not suitable inside a running loop; prefer decorating
|
||||
async functions when integrating with async frameworks like FastAPI.
|
||||
|
||||
Args:
|
||||
similarity: Cosine similarity threshold (1.0=exact, 0.9=semantic)
|
||||
ttl: Time-to-live in seconds (None=forever)
|
||||
backend: Async storage backend (None=in-memory)
|
||||
namespace: Cache namespace for isolation
|
||||
enabled: Whether caching is enabled
|
||||
key_func: Custom cache key function
|
||||
|
||||
Returns:
|
||||
Decorated function with caching
|
||||
|
||||
Examples:
|
||||
>>> @cache(similarity=0.9, ttl=3600)
|
||||
... async def ask_llm(prompt: str) -> str:
|
||||
... return await call_ollama(prompt)
|
||||
|
||||
>>> @cache()
|
||||
... def ask_llm_sync(prompt: str) -> str:
|
||||
... return call_ollama_sync(prompt)
|
||||
"""
|
||||
_backend = backend or MemoryBackend()
|
||||
embedding_cache = EmbeddingCache()
|
||||
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
# ── Async wrapper ────────────────────────────────────────────────
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not enabled:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
start_time = time.time()
|
||||
prompt = _extract_prompt(args, kwargs) # type: ignore[arg-type]
|
||||
normalized = normalize_prompt(prompt)
|
||||
cache_key = (
|
||||
key_func(*args, **kwargs) # type: ignore[arg-type]
|
||||
if key_func
|
||||
else hash_prompt(normalized, namespace)
|
||||
)
|
||||
|
||||
# 1. Exact match
|
||||
entry = await _backend.get(cache_key)
|
||||
if entry is not None:
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
_stats_manager.record_hit(
|
||||
namespace,
|
||||
latency_saved_ms=latency_ms,
|
||||
saved_cost=entry.estimate_cost(0.001, 0.002),
|
||||
)
|
||||
return entry.response # type: ignore[return-value]
|
||||
|
||||
# 2. Semantic match
|
||||
if similarity < 1.0:
|
||||
query_embedding = await embedding_cache.aencode(normalized)
|
||||
result = await _backend.find_similar(
|
||||
query_embedding, threshold=similarity, namespace=namespace
|
||||
)
|
||||
if result is not None:
|
||||
_, matched_entry, _ = result
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
_stats_manager.record_hit(
|
||||
namespace,
|
||||
latency_saved_ms=latency_ms,
|
||||
saved_cost=matched_entry.estimate_cost(0.001, 0.002),
|
||||
)
|
||||
return matched_entry.response # type: ignore[return-value]
|
||||
|
||||
# 3. Cache miss — call through
|
||||
_stats_manager.record_miss(namespace)
|
||||
|
||||
try:
|
||||
response = await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
raise PromptCacheError(f"LLM function call failed: {e}") from e
|
||||
|
||||
embedding = None
|
||||
if similarity < 1.0:
|
||||
embedding = await embedding_cache.aencode(normalized)
|
||||
|
||||
await _backend.set(
|
||||
cache_key,
|
||||
CacheEntry(
|
||||
prompt=normalized,
|
||||
response=response,
|
||||
embedding=embedding,
|
||||
created_at=time.time(),
|
||||
ttl=ttl,
|
||||
namespace=namespace,
|
||||
hit_count=0,
|
||||
input_tokens=len(normalized) // 4,
|
||||
output_tokens=len(str(response)) // 4,
|
||||
),
|
||||
)
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
return async_wrapper # type: ignore[return-value]
|
||||
|
||||
else:
|
||||
# ── Sync wrapper (backwards compatibility) ───────────────────────
|
||||
# Drives async backends via a dedicated event loop per call.
|
||||
# Do NOT use inside a running event loop (e.g. FastAPI handlers).
|
||||
import asyncio
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
start_time = time.time()
|
||||
prompt = _extract_prompt(args, kwargs) # type: ignore[arg-type]
|
||||
normalized = normalize_prompt(prompt)
|
||||
cache_key = (
|
||||
key_func(*args, **kwargs) # type: ignore[arg-type]
|
||||
if key_func
|
||||
else hash_prompt(normalized, namespace)
|
||||
)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# 1. Exact match
|
||||
entry = loop.run_until_complete(_backend.get(cache_key))
|
||||
if entry is not None:
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
_stats_manager.record_hit(
|
||||
namespace,
|
||||
latency_saved_ms=latency_ms,
|
||||
saved_cost=entry.estimate_cost(0.001, 0.002),
|
||||
)
|
||||
return entry.response # type: ignore[return-value]
|
||||
|
||||
# 2. Semantic match
|
||||
if similarity < 1.0:
|
||||
query_embedding = embedding_cache.encode(normalized)
|
||||
result = loop.run_until_complete(
|
||||
_backend.find_similar(
|
||||
query_embedding, threshold=similarity, namespace=namespace
|
||||
)
|
||||
)
|
||||
if result is not None:
|
||||
_, matched_entry, _ = result
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
_stats_manager.record_hit(
|
||||
namespace,
|
||||
latency_saved_ms=latency_ms,
|
||||
saved_cost=matched_entry.estimate_cost(0.001, 0.002),
|
||||
)
|
||||
return matched_entry.response # type: ignore[return-value]
|
||||
|
||||
# 3. Cache miss
|
||||
_stats_manager.record_miss(namespace)
|
||||
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
raise PromptCacheError(f"LLM function call failed: {e}") from e
|
||||
|
||||
embedding = None
|
||||
if similarity < 1.0:
|
||||
embedding = embedding_cache.encode(normalized)
|
||||
|
||||
loop.run_until_complete(
|
||||
_backend.set(
|
||||
cache_key,
|
||||
CacheEntry(
|
||||
prompt=normalized,
|
||||
response=response,
|
||||
embedding=embedding,
|
||||
created_at=time.time(),
|
||||
ttl=ttl,
|
||||
namespace=namespace,
|
||||
hit_count=0,
|
||||
input_tokens=len(normalized) // 4,
|
||||
output_tokens=len(str(response)) // 4,
|
||||
),
|
||||
)
|
||||
)
|
||||
return response # type: ignore[return-value]
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return sync_wrapper # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Global default backend for utility functions
|
||||
_default_backend: Optional[BaseBackend] = None
|
||||
|
||||
|
||||
def get_default_backend() -> BaseBackend:
|
||||
"""Get default storage backend."""
|
||||
global _default_backend
|
||||
if _default_backend is None:
|
||||
_default_backend = MemoryBackend()
|
||||
return _default_backend
|
||||
|
||||
|
||||
def set_default_backend(backend: BaseBackend) -> None:
|
||||
"""Set default storage backend."""
|
||||
global _default_backend
|
||||
_default_backend = backend
|
||||
_stats_manager.set_backend(backend)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"cache",
|
||||
"CacheContext",
|
||||
"CachedLLM",
|
||||
"get_default_backend",
|
||||
"set_default_backend",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue