async-semantic-llm-cache/tests/test_similarity.py

209 lines
7.3 KiB
Python
Raw Normal View History

2026-03-06 15:54:47 +01:00
"""Tests for embedding generation and similarity matching."""
import importlib.util

import numpy as np
import pytest

from semantic_llm_cache.similarity import (
    DummyEmbeddingProvider,
    EmbeddingCache,
    OpenAIEmbeddingProvider,
    SentenceTransformerProvider,
    cosine_similarity,
    create_embedding_provider,
)
class TestCosineSimilarity:
    """Behavioural checks for the ``cosine_similarity`` helper."""

    def test_identical_vectors(self):
        """A vector compared with an equal copy of itself scores 1.0."""
        vec = [1.0, 2.0, 3.0]
        twin = list(vec)
        assert cosine_similarity(vec, twin) == pytest.approx(1.0)

    def test_orthogonal_vectors(self):
        """Perpendicular unit axes score 0.0."""
        x_axis = [1.0, 0.0, 0.0]
        y_axis = [0.0, 1.0, 0.0]
        assert cosine_similarity(x_axis, y_axis) == pytest.approx(0.0)

    def test_opposite_vectors(self):
        """A vector and its negation score -1.0."""
        vec = [1.0, 2.0, 3.0]
        negated = [-component for component in vec]
        assert cosine_similarity(vec, negated) == pytest.approx(-1.0)

    def test_zero_vectors(self):
        """A zero vector paired with a non-zero one yields exactly 0.0."""
        zero = [0.0] * 3
        other = [1.0, 2.0, 3.0]
        assert cosine_similarity(zero, other) == 0.0

    def test_numpy_array_input(self):
        """numpy arrays are accepted in place of Python lists."""
        arr = np.array([1.0, 2.0, 3.0])
        assert cosine_similarity(arr, arr.copy()) == pytest.approx(1.0)

    def test_mixed_dimensions(self):
        """Vectors of unequal length are rejected with a ValueError."""
        shorter = [1.0, 2.0]
        longer = [1.0, 2.0, 3.0]
        with pytest.raises(ValueError, match="dimension mismatch"):
            cosine_similarity(shorter, longer)
class TestDummyEmbeddingProvider:
    """Behavioural checks for the deterministic ``DummyEmbeddingProvider``."""

    def test_encode_returns_list(self):
        """encode() yields a plain list whose items are all floats."""
        vector = DummyEmbeddingProvider().encode("test prompt")
        assert isinstance(vector, list)
        assert all(isinstance(value, float) for value in vector)

    def test_encode_deterministic(self):
        """Encoding the same text twice gives equal embeddings."""
        provider = DummyEmbeddingProvider()
        prompt = "test prompt"
        first, second = provider.encode(prompt), provider.encode(prompt)
        assert first == second

    def test_encode_different_inputs(self):
        """Distinct texts map to distinct embeddings."""
        provider = DummyEmbeddingProvider()
        first, second = [provider.encode(p) for p in ("prompt 1", "prompt 2")]
        assert first != second

    def test_custom_dimension(self):
        """The ``dim`` argument controls the embedding length."""
        assert len(DummyEmbeddingProvider(dim=128).encode("test")) == 128

    def test_embedding_normalized(self):
        """Embeddings have (approximately) unit Euclidean norm."""
        vector = DummyEmbeddingProvider().encode("test prompt")
        assert np.linalg.norm(vector) == pytest.approx(1.0, rel=1e-5)
class TestSentenceTransformerProvider:
    """Tests for SentenceTransformerProvider."""

    # NOTE: these encode tests are skipped unconditionally by the marker,
    # even when sentence-transformers is installed.
    @pytest.mark.skip(reason="Requires sentence-transformers installation")
    def test_encode_returns_list(self):
        """Test encode returns list of floats."""
        provider = SentenceTransformerProvider()
        embedding = provider.encode("test prompt")
        assert isinstance(embedding, list)
        assert all(isinstance(x, float) for x in embedding)

    @pytest.mark.skip(reason="Requires sentence-transformers installation")
    def test_encode_deterministic(self):
        """Test same input produces same output."""
        provider = SentenceTransformerProvider()
        text = "test prompt"
        e1 = provider.encode(text)
        e2 = provider.encode(text)
        assert e1 == e2

    def test_import_error_without_package(self):
        """Test ImportError raised when package not installed.

        The missing-package path can only be exercised when
        ``sentence_transformers`` is not importable, so the test skips
        itself when the package is present.
        """
        # pytest.importorskip() skips when the module is *missing*, which is
        # the opposite of what this test needs; probe with find_spec instead.
        if importlib.util.find_spec("sentence_transformers") is not None:
            pytest.skip("sentence-transformers is installed, cannot test import error")
        # NOTE(review): the provider may raise on construction or on first
        # encode(); keeping both inside the block covers either design.
        with pytest.raises(ImportError):
            SentenceTransformerProvider().encode("test prompt")
class TestOpenAIEmbeddingProvider:
    """Tests for OpenAIEmbeddingProvider."""

    # NOTE: these encode tests are skipped unconditionally by the marker
    # (per the reason string, they need an OpenAI API key).
    @pytest.mark.skip(reason="Requires OpenAI API key")
    def test_encode_returns_list(self):
        """Test encode returns list of floats."""
        provider = OpenAIEmbeddingProvider()
        embedding = provider.encode("test prompt")
        assert isinstance(embedding, list)
        assert all(isinstance(x, float) for x in embedding)

    @pytest.mark.skip(reason="Requires OpenAI API key")
    def test_encode_deterministic(self):
        """Test repeated encodes of the same input have the same length."""
        provider = OpenAIEmbeddingProvider()
        text = "test prompt"
        e1 = provider.encode(text)
        e2 = provider.encode(text)
        # OpenAI embeddings may vary slightly, so only dimensionality is compared.
        assert len(e1) == len(e2)

    def test_import_error_without_package(self):
        """Test ImportError raised when package not installed.

        The missing-package path can only be exercised when ``openai`` is not
        importable, so the test skips itself when the package is present.
        """
        # pytest.importorskip() skips when the module is *missing*, which is
        # the opposite of what this test needs; probe with find_spec instead.
        if importlib.util.find_spec("openai") is not None:
            pytest.skip("openai is installed, cannot test import error")
        # NOTE(review): assumes the provider surfaces the missing package as
        # ImportError (on construction or first encode) before any API-key
        # validation — confirm against OpenAIEmbeddingProvider.
        with pytest.raises(ImportError):
            OpenAIEmbeddingProvider().encode("test prompt")
class TestEmbeddingCache:
    """Tests for EmbeddingCache."""

    def test_cache_provider(self):
        """Test cache returns the wrapped provider's embedding, consistently."""
        provider = DummyEmbeddingProvider()
        cache = EmbeddingCache(provider=provider)
        # Encode same text twice; the second call may be served from the cache.
        e1 = cache.encode("test prompt")
        e2 = cache.encode("test prompt")
        assert e1 == e2
        # e1 == e2 alone holds for any deterministic encoder; also check the
        # value matches what the configured (deterministic) provider produces.
        assert e1 == pytest.approx(provider.encode("test prompt"))

    def test_cache_clear(self):
        """Test cache can be cleared and recomputes the same embedding."""
        provider = DummyEmbeddingProvider()
        cache = EmbeddingCache(provider=provider)
        before = cache.encode("test prompt")
        cache.clear_cache()
        # Should still work after clear, and yield the same value again.
        embedding = cache.encode("test prompt")
        assert len(embedding) > 0
        assert embedding == pytest.approx(before)

    def test_cache_default_provider(self):
        """Test cache uses dummy provider by default."""
        cache = EmbeddingCache()
        embedding = cache.encode("test prompt")
        assert isinstance(embedding, list)
        assert len(embedding) > 0
class TestCreateEmbeddingProvider:
    """Tests for create_embedding_provider factory."""

    def test_create_dummy_provider(self):
        """Test creating dummy provider."""
        provider = create_embedding_provider("dummy")
        assert isinstance(provider, DummyEmbeddingProvider)

    def test_create_auto_provider(self):
        """Test auto provider creates sentence-transformers when available."""
        provider = create_embedding_provider("auto")
        # Creates SentenceTransformerProvider if available, else DummyEmbeddingProvider.
        assert isinstance(provider, (DummyEmbeddingProvider, SentenceTransformerProvider))

    def test_invalid_provider_type(self):
        """Test invalid provider type raises error."""
        with pytest.raises(ValueError, match="Unknown provider type"):
            create_embedding_provider("invalid_type")

    def test_custom_model_name(self):
        """Test the factory accepts a ``model_name`` keyword for the dummy type.

        Only acceptance of the keyword is checked here; whether the name is
        forwarded to a model-backed provider is not exercised.
        """
        provider = create_embedding_provider("dummy", model_name="custom-model")
        assert isinstance(provider, DummyEmbeddingProvider)