"""
Unit tests for embedding utilities and common functionality
Tests dimension consistency, batch processing, error handling patterns,
and other utilities common across embedding services.
"""
import pytest
from unittest.mock import patch, Mock, AsyncMock
import numpy as np
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
from trustgraph.exceptions import TooManyRequests
class MockEmbeddingProcessor:
    """Simple mock embedding processor for testing functionality"""

    def __init__(self, embedding_function=None, **params):
        # Store embedding function for mocking
        self.embedding_function = embedding_function
        self.model = params.get('model', 'test-model')

    async def on_embeddings(self, text):
        if self.embedding_function:
            return self.embedding_function(text)
        return [0.1, 0.2, 0.3, 0.4, 0.5]  # Default test embedding


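# A minimal sketch, not exercised by the real services: the same awaitable
# call pattern also works with unittest.mock.AsyncMock (imported above),
# which is handy when a test needs per-call side effects or await assertions.
async def test_async_mock_call_pattern():
    mock_embeddings = AsyncMock(return_value=[0.1, 0.2, 0.3])

    result = await mock_embeddings("Test text")

    assert result == [0.1, 0.2, 0.3]
    mock_embeddings.assert_awaited_once_with("Test text")

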
class TestEmbeddingDimensionConsistency:
    """Test cases for embedding dimension consistency"""

    async def test_consistent_dimensions_single_processor(self):
        """Test that a single processor returns consistent dimensions"""
        # Arrange
        dimension = 128

        def mock_embedding(text):
            return [0.1] * dimension

        processor = MockEmbeddingProcessor(embedding_function=mock_embedding)

        # Act
        results = []
        test_texts = ["Text 1", "Text 2", "Text 3", "Text 4", "Text 5"]
        for text in test_texts:
            result = await processor.on_embeddings(text)
            results.append(result)

        # Assert
        for result in results:
            assert len(result) == dimension, f"Expected dimension {dimension}, got {len(result)}"

        # All results should have the same dimensions
        first_dim = len(results[0])
        for i, result in enumerate(results[1:], 1):
            assert len(result) == first_dim, f"Dimension mismatch at index {i}"

    async def test_dimension_consistency_across_text_lengths(self):
        """Test dimension consistency across varying text lengths"""
        # Arrange
        dimension = 384

        def mock_embedding(text):
            # Dimension should not depend on text length
            return [0.1] * dimension

        processor = MockEmbeddingProcessor(embedding_function=mock_embedding)

        # Act - Test various text lengths
        test_texts = [
            "",  # Empty text
            "Hi",  # Very short
            "This is a medium length sentence for testing.",  # Medium
            "This is a very long text that should still produce embeddings of consistent dimension regardless of the input text length and content." * 10  # Very long
        ]
        results = []
        for text in test_texts:
            result = await processor.on_embeddings(text)
            results.append(result)

        # Assert
        for i, result in enumerate(results):
            assert len(result) == dimension, f"Text length {len(test_texts[i])} produced wrong dimension"

    def test_dimension_validation_different_models(self):
        """Test dimension validation for different model configurations"""
        # Arrange
        models_and_dims = [
            ("small-model", 128),
            ("medium-model", 384),
            ("large-model", 1536)
        ]

        # Act & Assert
        for model_name, expected_dim in models_and_dims:
            # Sanity-check the expected-dimension table itself
            test_vector = [0.1] * expected_dim
            assert len(test_vector) == expected_dim, f"Model {model_name} dimension mismatch"


class TestEmbeddingBatchProcessing:
    """Test cases for batch processing logic"""

    async def test_sequential_processing_maintains_order(self):
        """Test that sequential processing maintains text order"""
        # Arrange
        def mock_embedding(text):
            # Return embedding that encodes the text for verification
            return [ord(text[0]) / 255.0] if text else [0.0]  # Normalize to [0, 1]

        processor = MockEmbeddingProcessor(embedding_function=mock_embedding)

        # Act
        test_texts = ["A", "B", "C", "D", "E"]
        results = []
        for text in test_texts:
            result = await processor.on_embeddings(text)
            results.append((text, result))

        # Assert
        for i, (original_text, embedding) in enumerate(results):
            assert original_text == test_texts[i]
            expected_value = ord(test_texts[i][0]) / 255.0
            assert abs(embedding[0] - expected_value) < 0.001

    async def test_batch_processing_throughput(self):
        """Test batch processing capabilities"""
        # Arrange
        call_count = 0

        def mock_embedding(text):
            nonlocal call_count
            call_count += 1
            return [0.1, 0.2, 0.3]

        processor = MockEmbeddingProcessor(embedding_function=mock_embedding)

        # Act - Process multiple texts
        batch_size = 10
        test_texts = [f"Text {i}" for i in range(batch_size)]
        results = []
        for text in test_texts:
            result = await processor.on_embeddings(text)
            results.append(result)

        # Assert
        assert call_count == batch_size
        assert len(results) == batch_size
        for result in results:
            assert result == [0.1, 0.2, 0.3]

    async def test_concurrent_processing_simulation(self):
        """Test concurrent processing behavior simulation"""
        # Arrange
        import asyncio

        processing_times = []

        def mock_embedding(text):
            import time
            processing_times.append(time.time())
            return [len(text) / 100.0]  # Encoding text length

        processor = MockEmbeddingProcessor(embedding_function=mock_embedding)

        # Act - Simulate concurrent processing
        test_texts = [f"Text {i}" for i in range(5)]
        tasks = [processor.on_embeddings(text) for text in test_texts]
        results = await asyncio.gather(*tasks)

        # Assert
        assert len(results) == 5
        assert len(processing_times) == 5

        # Results should correspond to text lengths
        for i, result in enumerate(results):
            expected_value = len(test_texts[i]) / 100.0
            assert abs(result[0] - expected_value) < 0.001

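    # A minimal sketch of genuinely overlapping work, complementing the
    # simulation above: slow_embedding is a hypothetical stand-in for a
    # network-bound embedding call, and the timing bound only demonstrates
    # that asyncio.gather interleaves the awaits rather than running them
    # back to back.
    async def test_gather_overlaps_awaiting_calls(self):
        import asyncio
        import time

        async def slow_embedding(text):
            await asyncio.sleep(0.05)  # stand-in for a model/API round trip
            return [0.1]

        start = time.monotonic()
        results = await asyncio.gather(
            *[slow_embedding(f"Text {i}") for i in range(5)]
        )
        elapsed = time.monotonic() - start

        assert len(results) == 5
        # Five overlapping 50 ms awaits should finish well under the
        # 250 ms a strictly sequential run would need.
        assert elapsed < 0.2

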
class TestEmbeddingErrorHandling:
    """Test cases for error handling in embedding services"""

    async def test_embedding_function_error_handling(self):
        """Test error handling in embedding function"""
        # Arrange
        def failing_embedding(text):
            raise Exception("Embedding model failed")

        processor = MockEmbeddingProcessor(embedding_function=failing_embedding)

        # Act & Assert
        with pytest.raises(Exception, match="Embedding model failed"):
            await processor.on_embeddings("Test text")

    async def test_rate_limit_exception_propagation(self):
        """Test that rate limit exceptions are properly propagated"""
        # Arrange
        def rate_limited_embedding(text):
            raise TooManyRequests("Rate limit exceeded")

        processor = MockEmbeddingProcessor(embedding_function=rate_limited_embedding)

        # Act & Assert
        with pytest.raises(TooManyRequests, match="Rate limit exceeded"):
            await processor.on_embeddings("Test text")

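    # A minimal sketch of the retry loop a caller might wrap around
    # TooManyRequests; the bounded-retry policy here is illustrative only
    # and not the trustgraph services' actual backoff behavior.
    async def test_retry_after_rate_limit(self):
        attempts = 0

        def flaky_embedding(text):
            nonlocal attempts
            attempts += 1
            if attempts < 3:
                raise TooManyRequests("Rate limit exceeded")
            return [0.1, 0.2, 0.3]

        processor = MockEmbeddingProcessor(embedding_function=flaky_embedding)

        result = None
        for _ in range(5):  # bounded retries; a real caller would also sleep
            try:
                result = await processor.on_embeddings("Test text")
                break
            except TooManyRequests:
                continue

        assert attempts == 3
        assert result == [0.1, 0.2, 0.3]
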
    async def test_none_result_handling(self):
        """Test handling when embedding function returns None"""
        # Arrange
        def none_embedding(text):
            return None

        processor = MockEmbeddingProcessor(embedding_function=none_embedding)

        # Act
        result = await processor.on_embeddings("Test text")

        # Assert
        assert result is None

    async def test_invalid_embedding_format_handling(self):
        """Test handling of invalid embedding formats"""
        # Arrange
        def invalid_embedding(text):
            return "not a list"  # Invalid format

        processor = MockEmbeddingProcessor(embedding_function=invalid_embedding)

        # Act
        result = await processor.on_embeddings("Test text")

        # Assert
        assert result == "not a list"  # Returns what the function provides


class TestEmbeddingUtilities:
    """Test cases for embedding utility functions and helpers"""

    def test_vector_normalization_simulation(self):
        """Test vector normalization logic simulation"""
        # Arrange
        test_vectors = [
            [1.0, 2.0, 3.0],
            [0.5, -0.5, 1.0],
            [10.0, 20.0, 30.0]
        ]

        # Act - Simulate L2 normalization
        normalized_vectors = []
        for vector in test_vectors:
            magnitude = sum(x**2 for x in vector) ** 0.5
            if magnitude > 0:
                normalized = [x / magnitude for x in vector]
            else:
                normalized = vector
            normalized_vectors.append(normalized)

        # Assert
        for normalized in normalized_vectors:
            magnitude = sum(x**2 for x in normalized) ** 0.5
            assert abs(magnitude - 1.0) < 0.0001, "Vector should be unit length"

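    # The same check expressed with numpy (imported above but otherwise
    # unused); a small sketch confirming the hand-rolled L2 normalization
    # agrees with np.linalg.norm.
    def test_vector_normalization_matches_numpy(self):
        vector = np.array([1.0, 2.0, 3.0])
        normalized = vector / np.linalg.norm(vector)
        assert abs(np.linalg.norm(normalized) - 1.0) < 1e-6
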
    def test_cosine_similarity_calculation(self):
        """Test cosine similarity calculation between embeddings"""
        # Arrange
        vector1 = [1.0, 0.0, 0.0]
        vector2 = [0.0, 1.0, 0.0]
        vector3 = [1.0, 0.0, 0.0]  # Same as vector1

        # Act - Calculate cosine similarities
        def cosine_similarity(v1, v2):
            dot_product = sum(a * b for a, b in zip(v1, v2))
            mag1 = sum(x**2 for x in v1) ** 0.5
            mag2 = sum(x**2 for x in v2) ** 0.5
            return dot_product / (mag1 * mag2) if mag1 * mag2 > 0 else 0

        sim_12 = cosine_similarity(vector1, vector2)
        sim_13 = cosine_similarity(vector1, vector3)

        # Assert
        assert abs(sim_12 - 0.0) < 0.0001, "Orthogonal vectors should have 0 similarity"
        assert abs(sim_13 - 1.0) < 0.0001, "Identical vectors should have 1.0 similarity"

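    # Equivalent cosine-similarity check via numpy dot products; a sketch
    # cross-validating the pure-Python helper above.
    def test_cosine_similarity_matches_numpy(self):
        v1 = np.array([1.0, 0.0, 0.0])
        v2 = np.array([0.0, 1.0, 0.0])
        sim = float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))
        assert abs(sim) < 1e-6, "Orthogonal vectors should have 0 similarity"
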
    def test_embedding_validation_helpers(self):
        """Test embedding validation helper functions"""
        # Arrange
        valid_embeddings = [
            [0.1, 0.2, 0.3],
            [1.0, -1.0, 0.0],
            []  # Empty embedding
        ]
        invalid_embeddings = [
            None,
            "not a list",
            [1, 2, "three"],  # Mixed types
            [[1, 2], [3, 4]]  # Nested lists
        ]

        # Act & Assert
        def is_valid_embedding(embedding):
            if not isinstance(embedding, list):
                return False
            return all(isinstance(x, (int, float)) for x in embedding)

        for embedding in valid_embeddings:
            assert is_valid_embedding(embedding), f"Should be valid: {embedding}"

        for embedding in invalid_embeddings:
            assert not is_valid_embedding(embedding), f"Should be invalid: {embedding}"

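    # A batch-level extension of the same idea: every embedding in a batch
    # must share a single dimension. validate_batch_dimensions is a
    # hypothetical helper sketched for illustration, not a trustgraph API.
    def test_batch_dimension_validation(self):
        def validate_batch_dimensions(embeddings):
            return len({len(e) for e in embeddings}) <= 1

        assert validate_batch_dimensions([[0.1, 0.2], [0.3, 0.4]])
        assert not validate_batch_dimensions([[0.1, 0.2], [0.3]])
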
    async def test_embedding_metadata_handling(self):
        """Test handling of embedding metadata and properties"""
        # Arrange
        def metadata_embedding(text):
            return {
                "vectors": [0.1, 0.2, 0.3],
                "model": "test-model",
                "dimension": 3,
                "text_length": len(text)
            }

        # Mock processor that returns metadata
        class MetadataProcessor(MockEmbeddingProcessor):
            async def on_embeddings(self, text):
                result = metadata_embedding(text)
                return result["vectors"]  # Return only vectors for compatibility

        processor = MetadataProcessor()

        # Act
        result = await processor.on_embeddings("Test text with metadata")

        # Assert
        assert isinstance(result, list)
        assert len(result) == 3
        assert result == [0.1, 0.2, 0.3]
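
    # A minimal sketch tying these utilities back to the schema types
    # imported at the top of the module. The field names used here (text,
    # error, vectors) are assumptions inferred from those imports; adjust
    # them if the trustgraph.schema record definitions differ.
    def test_schema_round_trip_sketch(self):
        request = EmbeddingsRequest(text="Test text")
        response = EmbeddingsResponse(error=None, vectors=[[0.1, 0.2, 0.3]])

        assert request.text == "Test text"
        assert response.vectors[0] == [0.1, 0.2, 0.3]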