trustgraph/tests/unit/test_embeddings/test_embedding_logic.py

278 lines
9.4 KiB
Python
Raw Normal View History

Release/v1.2 (#457) * Bump setup.py versions for 1.1 * PoC MCP server (#419) * Very initial MCP server PoC for TrustGraph * Put service on port 8000 * Add MCP container and packages to buildout * Update docs for API/CLI changes in 1.0 (#421) * Update some API basics for the 0.23/1.0 API change * Add MCP container push (#425) * Add command args to the MCP server (#426) * Host and port parameters * Added websocket arg * More docs * MCP client support (#427) - MCP client service - Tool request/response schema - API gateway support for mcp-tool - Message translation for tool request & response - Make mcp-tool use the configuration service for information about where the MCP services are. * Feature/react call mcp (#428) Key Features - MCP Tool Integration: Added core MCP tool support with ToolClientSpec and ToolClient classes - API Enhancement: New mcp_tool method for flow-specific tool invocation - CLI Tooling: New tg-invoke-mcp-tool command for testing MCP integration - React Agent Enhancement: Fixed and improved multi-tool invocation capabilities - Tool Management: Enhanced CLI for tool configuration and management Changes - Added MCP tool invocation to API with flow-specific integration - Implemented ToolClientSpec and ToolClient for tool call handling - Updated agent-manager-react to invoke MCP tools with configurable types - Enhanced CLI with new commands and improved help text - Added comprehensive documentation for new CLI commands - Improved tool configuration management Testing - Added tg-invoke-mcp-tool CLI command for isolated MCP integration testing - Enhanced agent capability to invoke multiple tools simultaneously * Test suite executed from CI pipeline (#433) * Test strategy & test cases * Unit tests * Integration tests * Extending test coverage (#434) * Contract tests * Testing embeddings * Agent unit tests * Knowledge pipeline tests * Turn on contract tests * Increase storage test coverage (#435) * Fixing storage and adding tests * PR pipeline only runs 
quick tests * Empty configuration is returned as empty list, previously was not in response (#436) * Update config util to take files as well as command-line text (#437) * Updated CLI invocation and config model for tools and mcp (#438) * Updated CLI invocation and config model for tools and mcp * CLI anomalies * Tweaked the MCP tool implementation for new model * Update agent implementation to match the new model * Fix agent tools, now all tested * Fixed integration tests * Fix MCP delete tool params * Update Python deps to 1.2 * Update to enable knowledge extraction using the agent framework (#439) * Implement KG extraction agent (kg-extract-agent) * Using ReAct framework (agent-manager-react) * ReAct manager had an issue when emitting JSON, which conflicts with ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure. * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework. 
* Migrate from setup.py to pyproject.toml (#440) * Converted setup.py to pyproject.toml * Modern package infrastructure as recommended by py docs * Install missing build deps (#441) * Install missing build deps (#442) * Implement logging strategy (#444) * Logging strategy and convert all prints() to logging invocations * Fix/startup failure (#445) * Fix logging startup problems * Fix logging startup problems (#446) * Fix logging startup problems (#447) * Fixed Mistral OCR to use current API (#448) * Fixed Mistral OCR to use current API * Added PDF decoder tests * Fix Mistral OCR ident to be standard pdf-decoder (#450) * Fix Mistral OCR ident to be standard pdf-decoder * Correct test * Schema structure refactor (#451) * Write schema refactor spec * Implemented schema refactor spec * Structure data mvp (#452) * Structured data tech spec * Architecture principles * New schemas * Updated schemas and specs * Object extractor * Add .coveragerc * New tests * Cassandra object storage * Trying to get object extraction working, issues exist * Validate librarian collection (#453) * Fix token chunker, broken API invocation (#454) * Fix token chunker, broken API invocation (#455) * Knowledge load utility CLI (#456) * Knowledge loader * More tests
2025-08-18 20:56:09 +01:00
"""
Unit tests for embedding business logic
Tests the core embedding functionality without external dependencies,
focusing on data processing, validation, and business rules.
"""
import pytest
import numpy as np
from unittest.mock import Mock, patch
class TestEmbeddingBusinessLogic:
    """Test embedding business logic and data processing.

    Each test defines the small piece of logic it exercises (vector
    validation, dimension checks, normalization, cosine similarity, ...)
    inline, so the suite runs without any embedding service or model.
    """

    def test_embedding_vector_validation(self):
        """Test validation of embedding vectors"""
        # Arrange
        valid_vectors = [
            [0.1, 0.2, 0.3],
            [-0.5, 0.0, 0.8],
            [],  # Empty vector
            [1.0] * 1536,  # Large vector
        ]
        invalid_vectors = [
            None,
            "not a vector",
            [1, 2, "string"],
            [[1, 2], [3, 4]],  # Nested
            [True, False],  # bool subclasses int but is not a numeric component
        ]

        # Act & Assert
        def is_valid_vector(vec):
            """Return True if vec is a list whose items are all int/float (not bool)."""
            if not isinstance(vec, list):
                return False
            # isinstance(True, int) is True in Python, so reject bool explicitly.
            return all(
                isinstance(x, (int, float)) and not isinstance(x, bool)
                for x in vec
            )

        for vec in valid_vectors:
            assert is_valid_vector(vec), f"Should be valid: {vec}"
        for vec in invalid_vectors:
            assert not is_valid_vector(vec), f"Should be invalid: {vec}"

    def test_dimension_consistency_check(self):
        """Test dimension consistency validation"""
        # Arrange
        same_dimension_vectors = [
            [0.1, 0.2, 0.3, 0.4, 0.5],
            [0.6, 0.7, 0.8, 0.9, 1.0],
            [-0.1, -0.2, -0.3, -0.4, -0.5],
        ]
        mixed_dimension_vectors = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6, 0.7],
            [0.8, 0.9],
        ]

        # Act
        def check_dimension_consistency(vectors):
            """Return True if every vector has the same length as the first."""
            if not vectors:
                # An empty batch is trivially consistent.
                return True
            expected_dim = len(vectors[0])
            return all(len(vec) == expected_dim for vec in vectors)

        # Assert
        assert check_dimension_consistency(same_dimension_vectors)
        assert not check_dimension_consistency(mixed_dimension_vectors)
        # Boundary cases: empty batch and single vector.
        assert check_dimension_consistency([])
        assert check_dimension_consistency([[0.1, 0.2]])

    def test_text_preprocessing_logic(self):
        """Test text preprocessing for embeddings"""
        # Arrange
        test_cases = [
            ("Simple text", "Simple text"),
            ("", ""),
            ("Text with\nnewlines", "Text with\nnewlines"),
            ("Unicode: 世界 🌍", "Unicode: 世界 🌍"),
            (" Whitespace ", " Whitespace "),
            (None, ""),  # Missing text maps to the empty string
        ]

        # Act & Assert
        for input_text, expected in test_cases:
            # Simple preprocessing: identity for strings, "" for None.
            processed = str(input_text) if input_text is not None else ""
            assert processed == expected

    def test_batch_processing_logic(self):
        """Test batch processing logic for multiple texts"""
        # Arrange
        texts = ["Text 1", "Text 2", "Text 3"]

        def mock_embed_single(text):
            # Simulate embedding generation based on text length
            return [len(text) / 10.0] * 5

        # Act
        results = [(text, mock_embed_single(text)) for text in texts]

        # Assert
        assert len(results) == len(texts)
        for text, (original_text, embedding) in zip(texts, results):
            # Results must stay paired with their input, in input order.
            assert original_text == text
            assert len(embedding) == 5
            expected_value = len(text) / 10.0
            assert all(abs(val - expected_value) < 0.001 for val in embedding)

    def test_numpy_array_conversion_logic(self):
        """Test numpy array to list conversion"""
        # Arrange
        test_arrays = [
            np.array([1, 2, 3], dtype=np.int32),
            np.array([1.0, 2.0, 3.0], dtype=np.float64),
            np.array([0.1, 0.2, 0.3], dtype=np.float32),
        ]

        # Act
        converted = [arr.tolist() for arr in test_arrays]

        # Assert
        # tolist() must yield native Python scalars, not numpy scalar types.
        assert converted[0] == [1, 2, 3]
        assert all(isinstance(x, int) for x in converted[0])
        assert converted[1] == [1.0, 2.0, 3.0]
        # float32 cannot represent 0.1/0.2/0.3 exactly, so compare approximately.
        assert len(converted[2]) == 3
        assert all(isinstance(x, float) for x in converted[2])
        assert np.allclose(converted[2], [0.1, 0.2, 0.3])

    def test_error_response_generation(self):
        """Test error response generation logic"""
        # Arrange
        error_scenarios = [
            ("model_not_found", "Model 'xyz' not found"),
            ("connection_error", "Failed to connect to service"),
            ("rate_limit", "Rate limit exceeded"),
            ("invalid_input", "Invalid input format"),
        ]

        # Act & Assert
        for error_type, error_message in error_scenarios:
            error_response = {
                "error": {
                    "type": error_type,
                    "message": error_message,
                },
                "vectors": None,
            }
            assert error_response["error"]["type"] == error_type
            assert error_response["error"]["message"] == error_message
            # An error response carries no vectors.
            assert error_response["vectors"] is None

    def test_success_response_generation(self):
        """Test success response generation logic"""
        # Arrange
        test_vectors = [0.1, 0.2, 0.3, 0.4, 0.5]

        # Act
        success_response = {
            "error": None,
            "vectors": test_vectors,
        }

        # Assert
        assert success_response["error"] is None
        assert success_response["vectors"] == test_vectors
        assert len(success_response["vectors"]) == 5

    def test_model_parameter_handling(self):
        """Test model parameter validation and handling"""
        # Arrange
        valid_models = {
            "ollama": ["mxbai-embed-large", "nomic-embed-text"],
            "fastembed": ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"],
        }

        # Act & Assert
        for provider, models in valid_models.items():
            for model in models:
                assert isinstance(model, str)
                assert len(model) > 0
                if provider == "fastembed":
                    assert "/" in model or "-" in model

    def test_concurrent_processing_simulation(self):
        """Test concurrent processing simulation"""
        # Arrange
        import asyncio

        async def mock_async_embed(text, delay=0.001):
            await asyncio.sleep(delay)
            return [ord(text[0]) / 255.0] if text else [0.0]

        # Act
        async def run_concurrent():
            texts = ["A", "B", "C", "D", "E"]
            tasks = [mock_async_embed(text) for text in texts]
            # gather() returns results in the order the awaitables were passed.
            results = await asyncio.gather(*tasks)
            return list(zip(texts, results))

        # Run test
        results = asyncio.run(run_concurrent())

        # Assert
        assert len(results) == 5
        for i, (text, embedding) in enumerate(results):
            expected_char = chr(ord('A') + i)
            assert text == expected_char
            expected_value = ord(expected_char) / 255.0
            assert abs(embedding[0] - expected_value) < 0.001

    def test_empty_and_edge_cases(self):
        """Test empty inputs and edge cases"""
        # Arrange
        edge_cases = [
            ("", "empty string"),
            (" ", "single space"),
            ("a", "single character"),
            ("A" * 10000, "very long string"),
            ("\n\t\r", "special characters"),  # real control characters
            ("混合English中文", "mixed languages"),
        ]

        # Act & Assert
        for text, description in edge_cases:
            # Basic validation that text can be processed
            assert isinstance(text, str), f"Failed for {description}"
            # Text must survive a UTF-8 round trip unchanged (covers the
            # control-character and mixed-script cases).
            assert text.encode("utf-8").decode("utf-8") == text, f"Failed for {description}"
            # Simulate embedding generation would work
            mock_embedding = [len(text) % 10] * 3
            assert len(mock_embedding) == 3, f"Failed for {description}"

    def test_vector_normalization_logic(self):
        """Test vector normalization calculations"""
        # Arrange: (input vector, expected unit vector or None if undefined)
        test_cases = [
            ([3.0, 4.0], [0.6, 0.8]),
            ([1.0, 0.0], [1.0, 0.0]),
            ([0.0, 0.0], None),  # Zero vector edge case: cannot be normalized
        ]

        # Act & Assert
        for vector, expected in test_cases:
            magnitude = sum(x**2 for x in vector) ** 0.5
            if magnitude > 0:
                normalized = [x / magnitude for x in vector]
                assert np.allclose(normalized, expected)
                # Check unit length (approximately)
                norm_magnitude = sum(x**2 for x in normalized) ** 0.5
                assert abs(norm_magnitude - 1.0) < 0.0001
            else:
                # Zero vector case
                assert expected is None
                assert all(x == 0 for x in vector)

    def test_cosine_similarity_calculation(self):
        """Test cosine similarity computation"""
        # Arrange
        vector_pairs = [
            ([1, 0], [0, 1], 0.0),  # Orthogonal
            ([1, 0], [1, 0], 1.0),  # Identical
            ([1, 1], [-1, -1], -1.0),  # Opposite
            ([0, 0], [1, 0], 0.0),  # Zero magnitude falls back to 0
        ]

        # Act & Assert
        def cosine_similarity(v1, v2):
            """Return cos(v1, v2), or 0 when either vector has zero magnitude."""
            dot = sum(a * b for a, b in zip(v1, v2))
            mag1 = sum(x**2 for x in v1) ** 0.5
            mag2 = sum(x**2 for x in v2) ** 0.5
            return dot / (mag1 * mag2) if mag1 * mag2 > 0 else 0

        for v1, v2, expected in vector_pairs:
            similarity = cosine_similarity(v1, v2)
            assert abs(similarity - expected) < 0.0001