# async-semantic-llm-cache/semantic_llm_cache/similarity.py
"""Embedding generation and similarity matching for llm-semantic-cache."""
import asyncio
import hashlib
from functools import lru_cache
from typing import Optional
import numpy as np
from semantic_llm_cache.exceptions import PromptCacheError


class EmbeddingProvider:
    """Base class for embedding providers."""

    def encode(self, text: str) -> list[float]:
        """Generate embedding for text.

        Args:
            text: Input text to encode

        Returns:
            Embedding vector as list of floats
        """
        raise NotImplementedError


class DummyEmbeddingProvider(EmbeddingProvider):
    """Fallback embedding provider using hash-based vectors.

    Provides consistent embeddings without external dependencies.
    Not semantically meaningful, but provides consistent cache keys.
    """

    def __init__(self, dim: int = 384) -> None:
        """Initialize dummy provider.

        Args:
            dim: Embedding dimension (matches MiniLM default)
        """
        self._dim = dim

    def encode(self, text: str) -> list[float]:
        """Generate hash-based embedding for text.

        Args:
            text: Input text to encode

        Returns:
            Deterministic embedding vector based on the text's SHA-256 hash
        """
        hash_obj = hashlib.sha256(text.encode())
        hash_bytes = hash_obj.digest()
        values = np.frombuffer(hash_bytes, dtype=np.uint8)[: self._dim].astype(
            np.float32
        )
        if len(values) < self._dim:
            values = np.pad(values, (0, self._dim - len(values)))
        norm = np.linalg.norm(values)
        if norm > 0:
            values = values / norm
        return values.tolist()
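
# Usage sketch (illustrative comment, not part of the module): the hash-based
# provider is deterministic, so identical texts always map to identical
# vectors, which keeps cache keys stable across processes. Distinct texts
# collide only with negligible probability.
#
#     provider = DummyEmbeddingProvider(dim=8)
#     assert provider.encode("hello") == provider.encode("hello")
#     assert provider.encode("hello") != provider.encode("goodbye")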


class SentenceTransformerProvider(EmbeddingProvider):
    """Sentence-transformers based embedding provider.

    Uses local models like MiniLM for semantic embeddings.
    Inference is CPU/GPU-bound; use aencode() from async contexts.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
        """Initialize sentence-transformer provider.

        Args:
            model_name: Name of sentence-transformer model
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as e:
            raise PromptCacheError(
                "sentence-transformers package required for semantic matching. "
                "Install with: pip install semantic-llm-cache[semantic]"
            ) from e
        self._model = SentenceTransformer(model_name)
        self._dim = self._model.get_sentence_embedding_dimension()

    def encode(self, text: str) -> list[float]:
        """Generate embedding for text (blocking; use aencode from async code).

        Args:
            text: Input text to encode

        Returns:
            Normalized embedding vector
        """
        embedding = self._model.encode(text, convert_to_numpy=True)
        embedding = np.asarray(embedding, dtype=np.float32)
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        return embedding.tolist()
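
# Usage sketch (illustrative; requires the optional sentence-transformers
# extra and downloads the model on first use): paraphrases of the same
# question score close to 1.0 under cosine similarity, which is what makes
# semantic cache hits possible.
#
#     provider = SentenceTransformerProvider()
#     a = provider.encode("What is the capital of France?")
#     b = provider.encode("France's capital city is?")
#     # cosine_similarity(a, b) is typically high, e.g. above a 0.8 threshold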


class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI API-based embedding provider.

    Uses OpenAI's embedding API for high-quality semantic embeddings.
    Every call performs network I/O; always use aencode() from async contexts.
    """

    def __init__(
        self, api_key: Optional[str] = None, model: str = "text-embedding-3-small"
    ) -> None:
        """Initialize OpenAI embedding provider.

        Args:
            api_key: OpenAI API key (uses OPENAI_API_KEY env var if None)
            model: OpenAI embedding model to use
        """
        try:
            import openai
        except ImportError as e:
            raise PromptCacheError(
                "openai package required for OpenAI embeddings. "
                "Install with: pip install semantic-llm-cache[openai]"
            ) from e
        self._client = openai.OpenAI(api_key=api_key)
        self._model = model

    def encode(self, text: str) -> list[float]:
        """Generate embedding for text (blocking; use aencode from async code).

        Args:
            text: Input text to encode

        Returns:
            OpenAI embedding vector, re-normalized to unit length defensively
        """
        response = self._client.embeddings.create(input=text, model=self._model)
        embedding = response.data[0].embedding
        embedding_arr = np.asarray(embedding, dtype=np.float32)
        norm = np.linalg.norm(embedding_arr)
        if norm > 0:
            embedding_arr = embedding_arr / norm
        return embedding_arr.tolist()
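
# Usage sketch (illustrative; requires the optional openai extra and an
# OPENAI_API_KEY in the environment; every encode() is a billed API call):
#
#     provider = OpenAIEmbeddingProvider()  # defaults to text-embedding-3-small
#     vec = provider.encode("hello world")  # blocking network round trip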


def cosine_similarity(
    a: list[float] | np.ndarray, b: list[float] | np.ndarray
) -> float:
    """Calculate cosine similarity between two vectors.

    Args:
        a: First vector
        b: Second vector

    Returns:
        Similarity score between -1 and 1 (1 means the same direction)

    Raises:
        ValueError: If vectors have different dimensions
    """
    a_arr = np.asarray(a, dtype=np.float32)
    b_arr = np.asarray(b, dtype=np.float32)
    if a_arr.shape != b_arr.shape:
        raise ValueError(
            f"Vector dimension mismatch: {a_arr.shape} != {b_arr.shape}"
        )
    dot_product = np.dot(a_arr, b_arr)
    norm_a = np.linalg.norm(a_arr)
    norm_b = np.linalg.norm(b_arr)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(dot_product / (norm_a * norm_b))
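
# Worked example (illustrative): for unit vectors, cosine similarity reduces
# to the dot product:
#
#     cosine_similarity([1.0, 0.0], [1.0, 0.0])   # -> 1.0   same direction
#     cosine_similarity([1.0, 0.0], [0.0, 1.0])   # -> 0.0   orthogonal
#     cosine_similarity([1.0, 0.0], [-1.0, 0.0])  # -> -1.0  opposite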


def _encode_with_provider(text: str, provider: EmbeddingProvider) -> tuple[float, ...]:
    """Helper function for LRU cache encoding.

    Args:
        text: Input text
        provider: Embedding provider

    Returns:
        Embedding as tuple for hashability
    """
    return tuple(provider.encode(text))
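
# Design note: this module-level helper lets each EmbeddingCache wrap its own
# functools.lru_cache. The cache key is the (text, provider) pair; provider
# instances hash by object identity, so no custom __hash__ is needed and
# entries from different providers never mix.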


class EmbeddingCache:
    """Cache for embedding generation with LRU eviction.

    Use encode() from sync contexts, aencode() from async contexts.
    aencode() offloads blocking inference to a thread pool via asyncio.to_thread.
    """

    def __init__(
        self,
        provider: Optional[EmbeddingProvider] = None,
        cache_size: int = 1024,
    ) -> None:
        """Initialize embedding cache.

        Args:
            provider: Embedding provider (uses DummyEmbeddingProvider if None)
            cache_size: Maximum number of embeddings to cache
        """
        self._provider = provider or DummyEmbeddingProvider()
        self._cache_size = cache_size
        self._get_cached = lru_cache(maxsize=cache_size)(_encode_with_provider)

    def encode(self, text: str) -> list[float]:
        """Generate embedding with LRU caching (sync, blocking).

        Args:
            text: Input text to encode

        Returns:
            Embedding vector
        """
        return list(self._get_cached(text, self._provider))

    async def aencode(self, text: str) -> list[float]:
        """Generate embedding with LRU caching (async, non-blocking).

        CPU/network-bound work is offloaded to the default thread pool via
        asyncio.to_thread, keeping the event loop free.

        Args:
            text: Input text to encode

        Returns:
            Embedding vector
        """
        return await asyncio.to_thread(self.encode, text)

    def clear_cache(self) -> None:
        """Clear the embedding LRU cache."""
        self._get_cached.cache_clear()
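
# Usage sketch (illustrative): from async code, prefer aencode() so that model
# inference or API calls never block the event loop.
#
#     cache = EmbeddingCache(create_embedding_provider("auto"))
#     vec = await cache.aencode("user prompt")  # inside a coroutine
#     same = cache.encode("user prompt")        # sync path hits the LRU cache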


def create_embedding_provider(
    provider_type: str = "auto",
    model_name: Optional[str] = None,
) -> EmbeddingProvider:
    """Create embedding provider based on type.

    Args:
        provider_type: Type of provider ("auto", "sentence-transformer",
            "openai", or "dummy"). "auto" prefers sentence-transformers and
            falls back to the dummy provider if it is not installed.
        model_name: Optional model name to use

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If provider_type is not recognized
    """
    if provider_type == "auto":
        try:
            return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")
        except PromptCacheError:
            return DummyEmbeddingProvider()
    if provider_type == "sentence-transformer":
        return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")
    if provider_type == "openai":
        # Fall back to the class default rather than passing model=None through.
        return OpenAIEmbeddingProvider(model=model_name or "text-embedding-3-small")
    if provider_type == "dummy":
        return DummyEmbeddingProvider()
    raise ValueError(f"Unknown provider type: {provider_type}")