Add files via upload
initial commit
This commit is contained in:
parent
8d3d5ff628
commit
b33bb415dd
24 changed files with 4840 additions and 0 deletions
283
semantic_llm_cache/similarity.py
Normal file
283
semantic_llm_cache/similarity.py
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
"""Embedding generation and similarity matching for llm-semantic-cache."""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from semantic_llm_cache.exceptions import PromptCacheError
|
||||
|
||||
|
||||
class EmbeddingProvider:
|
||||
"""Base class for embedding providers."""
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate embedding for text.
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DummyEmbeddingProvider(EmbeddingProvider):
|
||||
"""Fallback embedding provider using hash-based vectors.
|
||||
|
||||
Provides consistent embeddings without external dependencies.
|
||||
Not semantically meaningful but provides consistent cache keys.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 384) -> None:
|
||||
"""Initialize dummy provider.
|
||||
|
||||
Args:
|
||||
dim: Embedding dimension (matches MiniLM default)
|
||||
"""
|
||||
self._dim = dim
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate hash-based embedding for text.
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Deterministic embedding vector based on text hash
|
||||
"""
|
||||
hash_obj = hashlib.sha256(text.encode())
|
||||
hash_bytes = hash_obj.digest()
|
||||
|
||||
values = np.frombuffer(hash_bytes, dtype=np.uint8)[: self._dim].astype(
|
||||
np.float32
|
||||
)
|
||||
|
||||
if len(values) < self._dim:
|
||||
values = np.pad(values, (0, self._dim - len(values)))
|
||||
|
||||
norm = np.linalg.norm(values)
|
||||
if norm > 0:
|
||||
values = values / norm
|
||||
|
||||
return values.tolist()
|
||||
|
||||
|
||||
class SentenceTransformerProvider(EmbeddingProvider):
|
||||
"""Sentence-transformers based embedding provider.
|
||||
|
||||
Uses local models like MiniLM for semantic embeddings.
|
||||
Inference is CPU/GPU-bound; use aencode() from async contexts.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
|
||||
"""Initialize sentence-transformer provider.
|
||||
|
||||
Args:
|
||||
model_name: Name of sentence-transformer model
|
||||
"""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError as e:
|
||||
raise PromptCacheError(
|
||||
"sentence-transformers package required for semantic matching. "
|
||||
"Install with: pip install semantic-llm-cache[semantic]"
|
||||
) from e
|
||||
|
||||
self._model = SentenceTransformer(model_name)
|
||||
self._dim = self._model.get_sentence_embedding_dimension()
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate embedding for text (blocking — use aencode from async code).
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Normalized embedding vector
|
||||
"""
|
||||
embedding = self._model.encode(text, convert_to_numpy=True)
|
||||
embedding = np.asarray(embedding, dtype=np.float32)
|
||||
|
||||
norm = np.linalg.norm(embedding)
|
||||
if norm > 0:
|
||||
embedding = embedding / norm
|
||||
|
||||
return embedding.tolist()
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI API-based embedding provider.
|
||||
|
||||
Uses OpenAI's embedding API for high-quality semantic embeddings.
|
||||
Network I/O — always use aencode() from async contexts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, api_key: Optional[str] = None, model: str = "text-embedding-3-small"
|
||||
) -> None:
|
||||
"""Initialize OpenAI embedding provider.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key (uses OPENAI_API_KEY env var if None)
|
||||
model: OpenAI embedding model to use
|
||||
"""
|
||||
try:
|
||||
import openai
|
||||
except ImportError as e:
|
||||
raise PromptCacheError(
|
||||
"openai package required for OpenAI embeddings. "
|
||||
"Install with: pip install semantic-llm-cache[openai]"
|
||||
) from e
|
||||
|
||||
self._client = openai.OpenAI(api_key=api_key)
|
||||
self._model = model
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate embedding for text (blocking — use aencode from async code).
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
OpenAI embedding vector (already normalized)
|
||||
"""
|
||||
response = self._client.embeddings.create(input=text, model=self._model)
|
||||
embedding = response.data[0].embedding
|
||||
|
||||
embedding_arr = np.asarray(embedding, dtype=np.float32)
|
||||
norm = np.linalg.norm(embedding_arr)
|
||||
if norm > 0:
|
||||
embedding_arr = embedding_arr / norm
|
||||
|
||||
return embedding_arr.tolist()
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float] | np.ndarray, b: list[float] | np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two vectors.
|
||||
|
||||
Args:
|
||||
a: First vector
|
||||
b: Second vector
|
||||
|
||||
Returns:
|
||||
Similarity score between 0 and 1
|
||||
|
||||
Raises:
|
||||
ValueError: If vectors have different dimensions
|
||||
"""
|
||||
a_arr = np.asarray(a, dtype=np.float32)
|
||||
b_arr = np.asarray(b, dtype=np.float32)
|
||||
|
||||
if a_arr.shape != b_arr.shape:
|
||||
raise ValueError(
|
||||
f"Vector dimension mismatch: {a_arr.shape} != {b_arr.shape}"
|
||||
)
|
||||
|
||||
dot_product = np.dot(a_arr, b_arr)
|
||||
norm_a = np.linalg.norm(a_arr)
|
||||
norm_b = np.linalg.norm(b_arr)
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
|
||||
return float(dot_product / (norm_a * norm_b))
|
||||
|
||||
|
||||
def _encode_with_provider(text: str, provider: EmbeddingProvider) -> tuple[float, ...]:
|
||||
"""Helper function for LRU cache encoding.
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
provider: Embedding provider
|
||||
|
||||
Returns:
|
||||
Embedding as tuple for hashability
|
||||
"""
|
||||
return tuple(provider.encode(text))
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Cache for embedding generation with LRU eviction.
|
||||
|
||||
Use encode() from sync contexts, aencode() from async contexts.
|
||||
aencode() offloads blocking inference to a thread pool via asyncio.to_thread.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: Optional[EmbeddingProvider] = None,
|
||||
cache_size: int = 1024,
|
||||
) -> None:
|
||||
"""Initialize embedding cache.
|
||||
|
||||
Args:
|
||||
provider: Embedding provider (uses DummyEmbeddingProvider if None)
|
||||
cache_size: Maximum number of embeddings to cache
|
||||
"""
|
||||
self._provider = provider or DummyEmbeddingProvider()
|
||||
self._cache_size = cache_size
|
||||
self._get_cached = lru_cache(maxsize=cache_size)(_encode_with_provider)
|
||||
|
||||
def encode(self, text: str) -> list[float]:
|
||||
"""Generate embedding with LRU caching (sync, blocking).
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Embedding vector
|
||||
"""
|
||||
return list(self._get_cached(text, self._provider))
|
||||
|
||||
async def aencode(self, text: str) -> list[float]:
|
||||
"""Generate embedding with LRU caching (async, non-blocking).
|
||||
|
||||
CPU/network-bound work is offloaded to the default thread pool via
|
||||
asyncio.to_thread, keeping the event loop free.
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Embedding vector
|
||||
"""
|
||||
return await asyncio.to_thread(self.encode, text)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the embedding LRU cache."""
|
||||
self._get_cached.cache_clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider_type: str = "auto",
|
||||
model_name: Optional[str] = None,
|
||||
) -> EmbeddingProvider:
|
||||
"""Create embedding provider based on type.
|
||||
|
||||
Args:
|
||||
provider_type: Type of provider ("auto", "sentence-transformer", "openai", "dummy")
|
||||
model_name: Optional model name to use
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
"""
|
||||
if provider_type == "auto":
|
||||
try:
|
||||
return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")
|
||||
except PromptCacheError:
|
||||
return DummyEmbeddingProvider()
|
||||
|
||||
if provider_type == "sentence-transformer":
|
||||
return SentenceTransformerProvider(model_name or "all-MiniLM-L6-v2")
|
||||
|
||||
if provider_type == "openai":
|
||||
return OpenAIEmbeddingProvider(model=model_name)
|
||||
|
||||
if provider_type == "dummy":
|
||||
return DummyEmbeddingProvider()
|
||||
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue