mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 00:46:22 +02:00
Release/v1.2 (#457)
* Bump setup.py versions for 1.1 * PoC MCP server (#419) * Very initial MCP server PoC for TrustGraph * Put service on port 8000 * Add MCP container and packages to buildout * Update docs for API/CLI changes in 1.0 (#421) * Update some API basics for the 0.23/1.0 API change * Add MCP container push (#425) * Add command args to the MCP server (#426) * Host and port parameters * Added websocket arg * More docs * MCP client support (#427) - MCP client service - Tool request/response schema - API gateway support for mcp-tool - Message translation for tool request & response - Make mcp-tool using configuration service for information about where the MCP services are. * Feature/react call mcp (#428) Key Features - MCP Tool Integration: Added core MCP tool support with ToolClientSpec and ToolClient classes - API Enhancement: New mcp_tool method for flow-specific tool invocation - CLI Tooling: New tg-invoke-mcp-tool command for testing MCP integration - React Agent Enhancement: Fixed and improved multi-tool invocation capabilities - Tool Management: Enhanced CLI for tool configuration and management Changes - Added MCP tool invocation to API with flow-specific integration - Implemented ToolClientSpec and ToolClient for tool call handling - Updated agent-manager-react to invoke MCP tools with configurable types - Enhanced CLI with new commands and improved help text - Added comprehensive documentation for new CLI commands - Improved tool configuration management Testing - Added tg-invoke-mcp-tool CLI command for isolated MCP integration testing - Enhanced agent capability to invoke multiple tools simultaneously * Test suite executed from CI pipeline (#433) * Test strategy & test cases * Unit tests * Integration tests * Extending test coverage (#434) * Contract tests * Testing embeedings * Agent unit tests * Knowledge pipeline tests * Turn on contract tests * Increase storage test coverage (#435) * Fixing storage and adding tests * PR pipeline only runs quick tests * Empty configuration is returned as empty list, previously was not in response (#436) * Update config util to take files as well as command-line text (#437) * Updated CLI invocation and config model for tools and mcp (#438) * Updated CLI invocation and config model for tools and mcp * CLI anomalies * Tweaked the MCP tool implementation for new model * Update agent implementation to match the new model * Fix agent tools, now all tested * Fixed integration tests * Fix MCP delete tool params * Update Python deps to 1.2 * Update to enable knowledge extraction using the agent framework (#439) * Implement KG extraction agent (kg-extract-agent) * Using ReAct framework (agent-manager-react) * ReAct manager had an issue when emitting JSON, which conflicts which ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure. * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework. * Migrate from setup.py to pyproject.toml (#440) * Converted setup.py to pyproject.toml * Modern package infrastructure as recommended by py docs * Install missing build deps (#441) * Install missing build deps (#442) * Implement logging strategy (#444) * Logging strategy and convert all prints() to logging invocations * Fix/startup failure (#445) * Fix loggin startup problems * Fix logging startup problems (#446) * Fix logging startup problems (#447) * Fixed Mistral OCR to use current API (#448) * Fixed Mistral OCR to use current API * Added PDF decoder tests * Fix Mistral OCR ident to be standard pdf-decoder (#450) * Fix Mistral OCR ident to be standard pdf-decoder * Correct test * Schema structure refactor (#451) * Write schema refactor spec * Implemented schema refactor spec * Structure data mvp (#452) * Structured data tech spec * Architecture principles * New schemas * Updated schemas and specs * Object extractor * Add .coveragerc * New tests * Cassandra object storage * Trying to object extraction working, issues exist * Validate librarian collection (#453) * Fix token chunker, broken API invocation (#454) * Fix token chunker, broken API invocation (#455) * Knowledge load utility CLI (#456) * Knowledge loader * More tests
This commit is contained in:
parent
c85ba197be
commit
89be656990
509 changed files with 49632 additions and 5159 deletions
162
tests/unit/test_storage/conftest.py
Normal file
162
tests/unit/test_storage/conftest.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""
|
||||
Shared fixtures for storage tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_storage_config():
|
||||
"""Base configuration for storage processors"""
|
||||
return {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-storage-processor'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_storage_config(base_storage_config):
|
||||
"""Configuration for Qdrant storage processors"""
|
||||
return base_storage_config | {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
"""Mock Qdrant client"""
|
||||
mock_client = MagicMock()
|
||||
mock_client.collection_exists.return_value = True
|
||||
mock_client.create_collection.return_value = None
|
||||
mock_client.upsert.return_value = None
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_uuid():
|
||||
"""Mock UUID generation"""
|
||||
mock_uuid = MagicMock()
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
|
||||
return mock_uuid
|
||||
|
||||
|
||||
# Document embeddings fixtures
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_message():
|
||||
"""Mock document embeddings message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test document chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_chunks():
|
||||
"""Mock document embeddings message with multiple chunks"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first document chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second document chunk'
|
||||
mock_chunk2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_vectors():
|
||||
"""Mock document embeddings message with multiple vectors per chunk"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_empty_chunk():
|
||||
"""Mock document embeddings message with empty chunk"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = "" # Empty string
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
return mock_message
|
||||
|
||||
|
||||
# Graph embeddings fixtures
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_message():
|
||||
"""Mock graph embeddings message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'test_entity'
|
||||
mock_entity.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_multiple_entities():
|
||||
"""Mock graph embeddings message with multiple entities"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.value = 'entity_one'
|
||||
mock_entity1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.value = 'entity_two'
|
||||
mock_entity2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_empty_entity():
|
||||
"""Mock graph embeddings message with empty entity"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = "" # Empty string
|
||||
mock_entity.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
return mock_message
|
||||
576
tests/unit/test_storage/test_cassandra_storage_logic.py
Normal file
576
tests/unit/test_storage/test_cassandra_storage_logic.py
Normal file
|
|
@ -0,0 +1,576 @@
|
|||
"""
|
||||
Standalone unit tests for Cassandra Storage Logic
|
||||
|
||||
Tests core Cassandra storage logic without requiring full package imports.
|
||||
This focuses on testing the business logic that would be used by the
|
||||
Cassandra object storage processor components.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import re
|
||||
from unittest.mock import Mock
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class MockField:
|
||||
"""Mock implementation of Field for testing"""
|
||||
|
||||
def __init__(self, name: str, type: str, primary: bool = False,
|
||||
required: bool = False, indexed: bool = False,
|
||||
enum_values: List[str] = None, size: int = 0):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.primary = primary
|
||||
self.required = required
|
||||
self.indexed = indexed
|
||||
self.enum_values = enum_values or []
|
||||
self.size = size
|
||||
|
||||
|
||||
class MockRowSchema:
|
||||
"""Mock implementation of RowSchema for testing"""
|
||||
|
||||
def __init__(self, name: str, description: str, fields: List[MockField]):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.fields = fields
|
||||
|
||||
|
||||
class MockCassandraStorageLogic:
|
||||
"""Mock implementation of Cassandra storage logic for testing"""
|
||||
|
||||
def __init__(self):
|
||||
self.known_keyspaces = set()
|
||||
self.known_tables = {} # keyspace -> set of table names
|
||||
|
||||
def sanitize_name(self, name: str) -> str:
|
||||
"""Sanitize names for Cassandra compatibility (keyspaces)"""
|
||||
# Replace non-alphanumeric characters with underscore
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||
# Ensure it starts with a letter
|
||||
if safe_name and not safe_name[0].isalpha():
|
||||
safe_name = 'o_' + safe_name
|
||||
return safe_name.lower()
|
||||
|
||||
def sanitize_table(self, name: str) -> str:
|
||||
"""Sanitize table names for Cassandra compatibility"""
|
||||
# Replace non-alphanumeric characters with underscore
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||
# Always prefix tables with o_
|
||||
safe_name = 'o_' + safe_name
|
||||
return safe_name.lower()
|
||||
|
||||
def get_cassandra_type(self, field_type: str, size: int = 0) -> str:
|
||||
"""Convert schema field type to Cassandra type"""
|
||||
# Handle None size
|
||||
if size is None:
|
||||
size = 0
|
||||
|
||||
type_mapping = {
|
||||
"string": "text",
|
||||
"integer": "bigint" if size > 4 else "int",
|
||||
"float": "double" if size > 4 else "float",
|
||||
"boolean": "boolean",
|
||||
"timestamp": "timestamp",
|
||||
"date": "date",
|
||||
"time": "time",
|
||||
"uuid": "uuid"
|
||||
}
|
||||
|
||||
return type_mapping.get(field_type, "text")
|
||||
|
||||
def convert_value(self, value: Any, field_type: str) -> Any:
|
||||
"""Convert value to appropriate type for Cassandra"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if field_type == "integer":
|
||||
return int(value)
|
||||
elif field_type == "float":
|
||||
return float(value)
|
||||
elif field_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ('true', '1', 'yes')
|
||||
return bool(value)
|
||||
elif field_type == "timestamp":
|
||||
# Handle timestamp conversion if needed
|
||||
return value
|
||||
else:
|
||||
return str(value)
|
||||
except Exception:
|
||||
# Fallback to string conversion
|
||||
return str(value)
|
||||
|
||||
def generate_table_cql(self, keyspace: str, table_name: str, schema: MockRowSchema) -> str:
|
||||
"""Generate CREATE TABLE CQL statement"""
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
safe_table = self.sanitize_table(table_name)
|
||||
|
||||
# Build column definitions
|
||||
columns = ["collection text"] # Collection is always part of table
|
||||
primary_key_fields = []
|
||||
|
||||
for field in schema.fields:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
cassandra_type = self.get_cassandra_type(field.type, field.size)
|
||||
columns.append(f"{safe_field_name} {cassandra_type}")
|
||||
|
||||
if field.primary:
|
||||
primary_key_fields.append(safe_field_name)
|
||||
|
||||
# Build primary key - collection is always first in partition key
|
||||
if primary_key_fields:
|
||||
primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))"
|
||||
else:
|
||||
# If no primary key defined, use collection and a synthetic id
|
||||
columns.append("synthetic_id uuid")
|
||||
primary_key = "PRIMARY KEY ((collection, synthetic_id))"
|
||||
|
||||
# Create table CQL
|
||||
create_table_cql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} (
|
||||
{', '.join(columns)},
|
||||
{primary_key}
|
||||
)
|
||||
"""
|
||||
|
||||
return create_table_cql.strip()
|
||||
|
||||
def generate_index_cql(self, keyspace: str, table_name: str, schema: MockRowSchema) -> List[str]:
|
||||
"""Generate CREATE INDEX CQL statements for indexed fields"""
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
safe_table = self.sanitize_table(table_name)
|
||||
|
||||
index_statements = []
|
||||
|
||||
for field in schema.fields:
|
||||
if field.indexed and not field.primary:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
index_name = f"{safe_table}_{safe_field_name}_idx"
|
||||
create_index_cql = f"""
|
||||
CREATE INDEX IF NOT EXISTS {index_name}
|
||||
ON {safe_keyspace}.{safe_table} ({safe_field_name})
|
||||
"""
|
||||
index_statements.append(create_index_cql.strip())
|
||||
|
||||
return index_statements
|
||||
|
||||
def generate_insert_cql(self, keyspace: str, table_name: str, schema: MockRowSchema,
|
||||
values: Dict[str, Any], collection: str) -> tuple[str, tuple]:
|
||||
"""Generate INSERT CQL statement and values tuple"""
|
||||
safe_keyspace = self.sanitize_name(keyspace)
|
||||
safe_table = self.sanitize_table(table_name)
|
||||
|
||||
# Build column names and values
|
||||
columns = ["collection"]
|
||||
value_list = [collection]
|
||||
placeholders = ["%s"]
|
||||
|
||||
# Check if we need a synthetic ID
|
||||
has_primary_key = any(field.primary for field in schema.fields)
|
||||
if not has_primary_key:
|
||||
import uuid
|
||||
columns.append("synthetic_id")
|
||||
value_list.append(uuid.uuid4())
|
||||
placeholders.append("%s")
|
||||
|
||||
# Process fields
|
||||
for field in schema.fields:
|
||||
safe_field_name = self.sanitize_name(field.name)
|
||||
raw_value = values.get(field.name)
|
||||
|
||||
# Convert value to appropriate type
|
||||
converted_value = self.convert_value(raw_value, field.type)
|
||||
|
||||
columns.append(safe_field_name)
|
||||
value_list.append(converted_value)
|
||||
placeholders.append("%s")
|
||||
|
||||
# Build insert query
|
||||
insert_cql = f"""
|
||||
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
|
||||
VALUES ({', '.join(placeholders)})
|
||||
"""
|
||||
|
||||
return insert_cql.strip(), tuple(value_list)
|
||||
|
||||
def validate_object_for_storage(self, obj_values: Dict[str, Any], schema: MockRowSchema) -> Dict[str, str]:
|
||||
"""Validate object values for storage, return errors if any"""
|
||||
errors = {}
|
||||
|
||||
# Check for missing required fields
|
||||
for field in schema.fields:
|
||||
if field.required and field.name not in obj_values:
|
||||
errors[field.name] = f"Required field '{field.name}' is missing"
|
||||
|
||||
# Check primary key fields are not None/empty
|
||||
if field.primary and field.name in obj_values:
|
||||
value = obj_values[field.name]
|
||||
if value is None or str(value).strip() == "":
|
||||
errors[field.name] = f"Primary key field '{field.name}' cannot be empty"
|
||||
|
||||
# Check enum constraints
|
||||
if field.enum_values and field.name in obj_values:
|
||||
value = obj_values[field.name]
|
||||
if value and value not in field.enum_values:
|
||||
errors[field.name] = f"Value '{value}' not in allowed enum values: {field.enum_values}"
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
class TestCassandraStorageLogic:
|
||||
"""Test cases for Cassandra storage business logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def storage_logic(self):
|
||||
return MockCassandraStorageLogic()
|
||||
|
||||
@pytest.fixture
|
||||
def customer_schema(self):
|
||||
return MockRowSchema(
|
||||
name="customer_records",
|
||||
description="Customer information",
|
||||
fields=[
|
||||
MockField(
|
||||
name="customer_id",
|
||||
type="string",
|
||||
primary=True,
|
||||
required=True,
|
||||
indexed=True
|
||||
),
|
||||
MockField(
|
||||
name="name",
|
||||
type="string",
|
||||
required=True
|
||||
),
|
||||
MockField(
|
||||
name="email",
|
||||
type="string",
|
||||
required=True,
|
||||
indexed=True
|
||||
),
|
||||
MockField(
|
||||
name="age",
|
||||
type="integer",
|
||||
size=4
|
||||
),
|
||||
MockField(
|
||||
name="status",
|
||||
type="string",
|
||||
indexed=True,
|
||||
enum_values=["active", "inactive", "suspended"]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def test_sanitize_name_keyspace(self, storage_logic):
|
||||
"""Test name sanitization for Cassandra keyspaces"""
|
||||
# Test various name patterns
|
||||
assert storage_logic.sanitize_name("simple_name") == "simple_name"
|
||||
assert storage_logic.sanitize_name("Name-With-Dashes") == "name_with_dashes"
|
||||
assert storage_logic.sanitize_name("name.with.dots") == "name_with_dots"
|
||||
assert storage_logic.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
|
||||
assert storage_logic.sanitize_name("name with spaces") == "name_with_spaces"
|
||||
assert storage_logic.sanitize_name("special!@#$%^chars") == "special______chars"
|
||||
|
||||
def test_sanitize_table_name(self, storage_logic):
|
||||
"""Test table name sanitization"""
|
||||
# Tables always get o_ prefix
|
||||
assert storage_logic.sanitize_table("simple_name") == "o_simple_name"
|
||||
assert storage_logic.sanitize_table("Name-With-Dashes") == "o_name_with_dashes"
|
||||
assert storage_logic.sanitize_table("name.with.dots") == "o_name_with_dots"
|
||||
assert storage_logic.sanitize_table("123_starts_with_number") == "o_123_starts_with_number"
|
||||
|
||||
def test_get_cassandra_type(self, storage_logic):
|
||||
"""Test field type conversion to Cassandra types"""
|
||||
# Basic type mappings
|
||||
assert storage_logic.get_cassandra_type("string") == "text"
|
||||
assert storage_logic.get_cassandra_type("boolean") == "boolean"
|
||||
assert storage_logic.get_cassandra_type("timestamp") == "timestamp"
|
||||
assert storage_logic.get_cassandra_type("uuid") == "uuid"
|
||||
|
||||
# Integer types with size hints
|
||||
assert storage_logic.get_cassandra_type("integer", size=2) == "int"
|
||||
assert storage_logic.get_cassandra_type("integer", size=8) == "bigint"
|
||||
|
||||
# Float types with size hints
|
||||
assert storage_logic.get_cassandra_type("float", size=2) == "float"
|
||||
assert storage_logic.get_cassandra_type("float", size=8) == "double"
|
||||
|
||||
# Unknown type defaults to text
|
||||
assert storage_logic.get_cassandra_type("unknown_type") == "text"
|
||||
|
||||
def test_convert_value(self, storage_logic):
|
||||
"""Test value conversion for different field types"""
|
||||
# Integer conversions
|
||||
assert storage_logic.convert_value("123", "integer") == 123
|
||||
assert storage_logic.convert_value(123.5, "integer") == 123
|
||||
assert storage_logic.convert_value(None, "integer") is None
|
||||
|
||||
# Float conversions
|
||||
assert storage_logic.convert_value("123.45", "float") == 123.45
|
||||
assert storage_logic.convert_value(123, "float") == 123.0
|
||||
|
||||
# Boolean conversions
|
||||
assert storage_logic.convert_value("true", "boolean") is True
|
||||
assert storage_logic.convert_value("false", "boolean") is False
|
||||
assert storage_logic.convert_value("1", "boolean") is True
|
||||
assert storage_logic.convert_value("0", "boolean") is False
|
||||
assert storage_logic.convert_value("yes", "boolean") is True
|
||||
assert storage_logic.convert_value("no", "boolean") is False
|
||||
|
||||
# String conversions
|
||||
assert storage_logic.convert_value(123, "string") == "123"
|
||||
assert storage_logic.convert_value(True, "string") == "True"
|
||||
|
||||
def test_generate_table_cql(self, storage_logic, customer_schema):
|
||||
"""Test CREATE TABLE CQL generation"""
|
||||
# Act
|
||||
cql = storage_logic.generate_table_cql("test_user", "customer_records", customer_schema)
|
||||
|
||||
# Assert
|
||||
assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in cql
|
||||
assert "collection text" in cql
|
||||
assert "customer_id text" in cql
|
||||
assert "name text" in cql
|
||||
assert "email text" in cql
|
||||
assert "age int" in cql
|
||||
assert "status text" in cql
|
||||
assert "PRIMARY KEY ((collection, customer_id))" in cql
|
||||
|
||||
def test_generate_table_cql_without_primary_key(self, storage_logic):
|
||||
"""Test table creation when no primary key is defined"""
|
||||
# Arrange
|
||||
schema = MockRowSchema(
|
||||
name="events",
|
||||
description="Event log",
|
||||
fields=[
|
||||
MockField(name="event_type", type="string"),
|
||||
MockField(name="timestamp", type="timestamp")
|
||||
]
|
||||
)
|
||||
|
||||
# Act
|
||||
cql = storage_logic.generate_table_cql("test_user", "events", schema)
|
||||
|
||||
# Assert
|
||||
assert "synthetic_id uuid" in cql
|
||||
assert "PRIMARY KEY ((collection, synthetic_id))" in cql
|
||||
|
||||
def test_generate_index_cql(self, storage_logic, customer_schema):
|
||||
"""Test CREATE INDEX CQL generation"""
|
||||
# Act
|
||||
index_statements = storage_logic.generate_index_cql("test_user", "customer_records", customer_schema)
|
||||
|
||||
# Assert
|
||||
# Should create indexes for customer_id, email, and status (indexed fields)
|
||||
# But not for customer_id since it's also primary
|
||||
assert len(index_statements) == 2 # email and status
|
||||
|
||||
# Check index creation
|
||||
index_texts = " ".join(index_statements)
|
||||
assert "o_customer_records_email_idx" in index_texts
|
||||
assert "o_customer_records_status_idx" in index_texts
|
||||
assert "CREATE INDEX IF NOT EXISTS" in index_texts
|
||||
assert "customer_id" not in index_texts # Primary keys don't get indexes
|
||||
|
||||
def test_generate_insert_cql(self, storage_logic, customer_schema):
|
||||
"""Test INSERT CQL generation"""
|
||||
# Arrange
|
||||
values = {
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"age": 30,
|
||||
"status": "active"
|
||||
}
|
||||
collection = "test_collection"
|
||||
|
||||
# Act
|
||||
insert_cql, value_tuple = storage_logic.generate_insert_cql(
|
||||
"test_user", "customer_records", customer_schema, values, collection
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert "INSERT INTO test_user.o_customer_records" in insert_cql
|
||||
assert "collection" in insert_cql
|
||||
assert "customer_id" in insert_cql
|
||||
assert "VALUES" in insert_cql
|
||||
assert "%s" in insert_cql
|
||||
|
||||
# Check values tuple
|
||||
assert value_tuple[0] == "test_collection" # collection
|
||||
assert "CUST001" in value_tuple # customer_id
|
||||
assert "John Doe" in value_tuple # name
|
||||
assert 30 in value_tuple # age (converted to int)
|
||||
|
||||
def test_generate_insert_cql_without_primary_key(self, storage_logic):
|
||||
"""Test INSERT CQL generation for schema without primary key"""
|
||||
# Arrange
|
||||
schema = MockRowSchema(
|
||||
name="events",
|
||||
description="Event log",
|
||||
fields=[MockField(name="event_type", type="string")]
|
||||
)
|
||||
values = {"event_type": "login"}
|
||||
|
||||
# Act
|
||||
insert_cql, value_tuple = storage_logic.generate_insert_cql(
|
||||
"test_user", "events", schema, values, "test_collection"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert "synthetic_id" in insert_cql
|
||||
assert len(value_tuple) == 3 # collection, synthetic_id, event_type
|
||||
# Check that synthetic_id is a UUID (has correct format)
|
||||
import uuid
|
||||
assert isinstance(value_tuple[1], uuid.UUID)
|
||||
|
||||
def test_validate_object_for_storage_success(self, storage_logic, customer_schema):
|
||||
"""Test successful object validation for storage"""
|
||||
# Arrange
|
||||
valid_values = {
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"age": 30,
|
||||
"status": "active"
|
||||
}
|
||||
|
||||
# Act
|
||||
errors = storage_logic.validate_object_for_storage(valid_values, customer_schema)
|
||||
|
||||
# Assert
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_object_missing_required_fields(self, storage_logic, customer_schema):
|
||||
"""Test object validation with missing required fields"""
|
||||
# Arrange
|
||||
invalid_values = {
|
||||
"customer_id": "CUST001",
|
||||
# Missing required 'name' and 'email' fields
|
||||
"status": "active"
|
||||
}
|
||||
|
||||
# Act
|
||||
errors = storage_logic.validate_object_for_storage(invalid_values, customer_schema)
|
||||
|
||||
# Assert
|
||||
assert len(errors) == 2
|
||||
assert "name" in errors
|
||||
assert "email" in errors
|
||||
assert "Required field" in errors["name"]
|
||||
|
||||
def test_validate_object_empty_primary_key(self, storage_logic, customer_schema):
|
||||
"""Test object validation with empty primary key"""
|
||||
# Arrange
|
||||
invalid_values = {
|
||||
"customer_id": "", # Empty primary key
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"status": "active"
|
||||
}
|
||||
|
||||
# Act
|
||||
errors = storage_logic.validate_object_for_storage(invalid_values, customer_schema)
|
||||
|
||||
# Assert
|
||||
assert len(errors) == 1
|
||||
assert "customer_id" in errors
|
||||
assert "Primary key field" in errors["customer_id"]
|
||||
assert "cannot be empty" in errors["customer_id"]
|
||||
|
||||
def test_validate_object_invalid_enum(self, storage_logic, customer_schema):
|
||||
"""Test object validation with invalid enum value"""
|
||||
# Arrange
|
||||
invalid_values = {
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"status": "invalid_status" # Not in enum
|
||||
}
|
||||
|
||||
# Act
|
||||
errors = storage_logic.validate_object_for_storage(invalid_values, customer_schema)
|
||||
|
||||
# Assert
|
||||
assert len(errors) == 1
|
||||
assert "status" in errors
|
||||
assert "not in allowed enum values" in errors["status"]
|
||||
|
||||
def test_complex_schema_with_all_features(self, storage_logic):
|
||||
"""Test complex schema with all field features"""
|
||||
# Arrange
|
||||
complex_schema = MockRowSchema(
|
||||
name="complex_table",
|
||||
description="Complex table with all features",
|
||||
fields=[
|
||||
MockField(name="id", type="uuid", primary=True, required=True),
|
||||
MockField(name="name", type="string", required=True, indexed=True),
|
||||
MockField(name="count", type="integer", size=8),
|
||||
MockField(name="price", type="float", size=8),
|
||||
MockField(name="active", type="boolean"),
|
||||
MockField(name="created", type="timestamp"),
|
||||
MockField(name="category", type="string", enum_values=["A", "B", "C"], indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
# Act - Generate table CQL
|
||||
table_cql = storage_logic.generate_table_cql("complex_db", "complex_table", complex_schema)
|
||||
|
||||
# Act - Generate index CQL
|
||||
index_statements = storage_logic.generate_index_cql("complex_db", "complex_table", complex_schema)
|
||||
|
||||
# Assert table creation
|
||||
assert "complex_db.o_complex_table" in table_cql
|
||||
assert "id uuid" in table_cql
|
||||
assert "count bigint" in table_cql # size 8 -> bigint
|
||||
assert "price double" in table_cql # size 8 -> double
|
||||
assert "active boolean" in table_cql
|
||||
assert "created timestamp" in table_cql
|
||||
assert "PRIMARY KEY ((collection, id))" in table_cql
|
||||
|
||||
# Assert index creation (name and category are indexed, but not id since it's primary)
|
||||
assert len(index_statements) == 2
|
||||
index_text = " ".join(index_statements)
|
||||
assert "name_idx" in index_text
|
||||
assert "category_idx" in index_text
|
||||
|
||||
def test_storage_workflow_simulation(self, storage_logic, customer_schema):
|
||||
"""Test complete storage workflow simulation"""
|
||||
keyspace = "customer_db"
|
||||
table_name = "customers"
|
||||
collection = "import_batch_1"
|
||||
|
||||
# Step 1: Generate table creation
|
||||
table_cql = storage_logic.generate_table_cql(keyspace, table_name, customer_schema)
|
||||
assert "CREATE TABLE IF NOT EXISTS" in table_cql
|
||||
|
||||
# Step 2: Generate indexes
|
||||
index_statements = storage_logic.generate_index_cql(keyspace, table_name, customer_schema)
|
||||
assert len(index_statements) > 0
|
||||
|
||||
# Step 3: Validate and insert object
|
||||
customer_data = {
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"age": 35,
|
||||
"status": "active"
|
||||
}
|
||||
|
||||
# Validate
|
||||
errors = storage_logic.validate_object_for_storage(customer_data, customer_schema)
|
||||
assert len(errors) == 0
|
||||
|
||||
# Generate insert
|
||||
insert_cql, values = storage_logic.generate_insert_cql(
|
||||
keyspace, table_name, customer_schema, customer_data, collection
|
||||
)
|
||||
|
||||
assert "customer_db.o_customers" in insert_cql
|
||||
assert values[0] == collection
|
||||
assert "CUST001" in values
|
||||
assert "John Doe" in values
|
||||
387
tests/unit/test_storage/test_doc_embeddings_milvus_storage.py
Normal file
387
tests/unit/test_storage/test_doc_embeddings_milvus_storage.py
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
"""
|
||||
Tests for Milvus document embeddings storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.doc_embeddings.milvus.write import Processor
|
||||
from trustgraph.schema import ChunkEmbeddings
|
||||
|
||||
|
||||
class TestMilvusDocEmbeddingsStorageProcessor:
|
||||
"""Test cases for Milvus document embeddings storage processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test document embeddings
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk=b"This is the first document chunk",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk=b"This is the second document chunk",
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.chunks = [chunk1, chunk2]
|
||||
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors') as mock_doc_vectors:
|
||||
mock_vecstore = MagicMock()
|
||||
mock_doc_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-milvus-de-storage',
|
||||
store_uri='http://localhost:19530'
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
|
||||
def test_processor_initialization_with_defaults(self, mock_doc_vectors):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_doc_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_doc_vectors.assert_called_once_with('http://localhost:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
|
||||
def test_processor_initialization_with_custom_params(self, mock_doc_vectors):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_doc_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
store_uri='http://custom-milvus:19530'
|
||||
)
|
||||
|
||||
mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_single_chunk(self, processor):
|
||||
"""Test storing document embeddings for a single chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify insert was called for each vector
|
||||
expected_calls = [
|
||||
([0.1, 0.2, 0.3], "Test document content"),
|
||||
([0.4, 0.5, 0.6], "Test document content"),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
for i, (expected_vec, expected_doc) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_doc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
|
||||
"""Test storing document embeddings for multiple chunks"""
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Verify insert was called for each vector of each chunk
|
||||
expected_calls = [
|
||||
# Chunk 1 vectors
|
||||
([0.1, 0.2, 0.3], "This is the first document chunk"),
|
||||
([0.4, 0.5, 0.6], "This is the first document chunk"),
|
||||
# Chunk 2 vectors
|
||||
([0.7, 0.8, 0.9], "This is the second document chunk"),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
for i, (expected_vec, expected_doc) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_doc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunk(self, processor):
|
||||
"""Test storing document embeddings with empty chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no insert was called for empty chunk
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_none_chunk(self, processor):
|
||||
"""Test storing document embeddings with None chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=None,
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no insert was called for None chunk
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor):
|
||||
"""Test storing document embeddings with mix of valid and invalid chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_chunk = ChunkEmbeddings(
|
||||
chunk=b"Valid document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
empty_chunk = ChunkEmbeddings(
|
||||
chunk=b"",
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
none_chunk = ChunkEmbeddings(
|
||||
chunk=None,
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.chunks = [valid_chunk, empty_chunk, none_chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify only valid chunk was inserted
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], "Valid document content"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunks_list(self, processor):
|
||||
"""Test storing document embeddings with empty chunks list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.chunks = []
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no insert was called
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
|
||||
"""Test storing document embeddings for chunk with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Document with no vectors",
|
||||
vectors=[]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no insert was called (no vectors to insert)
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test storing document embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Document with mixed dimensions",
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify all vectors were inserted regardless of dimension
|
||||
expected_calls = [
|
||||
([0.1, 0.2], "Document with mixed dimensions"),
|
||||
([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions"),
|
||||
([0.7, 0.8, 0.9], "Document with mixed dimensions"),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
for i, (expected_vec, expected_doc) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_doc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_unicode_content(self, processor):
|
||||
"""Test storing document embeddings with Unicode content"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify Unicode content was properly decoded and inserted
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_large_chunks(self, processor):
|
||||
"""Test storing document embeddings with large document chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a large document chunk
|
||||
large_content = "A" * 10000 # 10KB of content
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=large_content.encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify large content was inserted
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], large_content
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_whitespace_only_chunk(self, processor):
|
||||
"""Test storing document embeddings with whitespace-only chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b" \n\t ",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify whitespace content was inserted (not filtered out)
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], " \n\t "
|
||||
)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'store_uri')
|
||||
assert args.store_uri == 'http://localhost:19530'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--store-uri', 'http://custom-milvus:19530'
|
||||
])
|
||||
|
||||
assert args.store_uri == 'http://custom-milvus:19530'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
|
||||
|
||||
assert args.store_uri == 'http://short-milvus:19530'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.milvus.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.doc_embeddings.milvus.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
|
||||
)
|
||||
536
tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py
Normal file
536
tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py
Normal file
|
|
@ -0,0 +1,536 @@
|
|||
"""
|
||||
Tests for Pinecone document embeddings storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import uuid
|
||||
|
||||
from trustgraph.storage.doc_embeddings.pinecone.write import Processor
|
||||
from trustgraph.schema import ChunkEmbeddings
|
||||
|
||||
|
||||
class TestPineconeDocEmbeddingsStorageProcessor:
|
||||
"""Test cases for Pinecone document embeddings storage processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test document embeddings
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk=b"This is the first document chunk",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk=b"This is the second document chunk",
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.chunks = [chunk1, chunk2]
|
||||
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone') as mock_pinecone_class:
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-pinecone-de-storage',
|
||||
api_key='test-api-key'
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
|
||||
@patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'env-api-key')
|
||||
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
|
||||
"""Test processor initialization with default parameters"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
|
||||
assert processor.pinecone == mock_pinecone
|
||||
assert processor.api_key == 'env-api-key'
|
||||
assert processor.cloud == 'aws'
|
||||
assert processor.region == 'us-east-1'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
|
||||
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
api_key='custom-api-key',
|
||||
cloud='gcp',
|
||||
region='us-west1'
|
||||
)
|
||||
|
||||
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
|
||||
assert processor.api_key == 'custom-api-key'
|
||||
assert processor.cloud == 'gcp'
|
||||
assert processor.region == 'us-west1'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.pinecone.write.PineconeGRPC')
|
||||
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
|
||||
"""Test processor initialization with custom URL (GRPC mode)"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_grpc_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
api_key='test-api-key',
|
||||
url='https://custom-host.pinecone.io'
|
||||
)
|
||||
|
||||
mock_pinecone_grpc_class.assert_called_once_with(
|
||||
api_key='test-api-key',
|
||||
host='https://custom-host.pinecone.io'
|
||||
)
|
||||
assert processor.pinecone == mock_pinecone
|
||||
assert processor.url == 'https://custom-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'not-specified')
|
||||
def test_processor_initialization_missing_api_key(self):
|
||||
"""Test processor initialization fails with missing API key"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
|
||||
Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_single_chunk(self, processor):
|
||||
"""Test storing document embeddings for a single chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index operations
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index name and operations
|
||||
expected_index_name = "d-test_user-test_collection-3"
|
||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||
|
||||
# Verify upsert was called for each vector
|
||||
assert mock_index.upsert.call_count == 2
|
||||
|
||||
# Check first vector upsert
|
||||
first_call = mock_index.upsert.call_args_list[0]
|
||||
first_vectors = first_call[1]['vectors']
|
||||
assert len(first_vectors) == 1
|
||||
assert first_vectors[0]['id'] == 'id1'
|
||||
assert first_vectors[0]['values'] == [0.1, 0.2, 0.3]
|
||||
assert first_vectors[0]['metadata']['doc'] == "Test document content"
|
||||
|
||||
# Check second vector upsert
|
||||
second_call = mock_index.upsert.call_args_list[1]
|
||||
second_vectors = second_call[1]['vectors']
|
||||
assert len(second_vectors) == 1
|
||||
assert second_vectors[0]['id'] == 'id2'
|
||||
assert second_vectors[0]['values'] == [0.4, 0.5, 0.6]
|
||||
assert second_vectors[0]['metadata']['doc'] == "Test document content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
|
||||
"""Test storing document embeddings for multiple chunks"""
|
||||
# Mock index operations
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Verify upsert was called for each vector (3 total)
|
||||
assert mock_index.upsert.call_count == 3
|
||||
|
||||
# Verify document content in metadata
|
||||
calls = mock_index.upsert.call_args_list
|
||||
assert calls[0][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
|
||||
assert calls[1][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
|
||||
assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation(self, processor):
|
||||
"""Test automatic index creation when index doesn't exist"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist initially
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock index creation
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": True}
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index creation was called
|
||||
expected_index_name = "d-test_user-test_collection-3"
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
create_call = processor.pinecone.create_index.call_args
|
||||
assert create_call[1]['name'] == expected_index_name
|
||||
assert create_call[1]['dimension'] == 3
|
||||
assert create_call[1]['metric'] == "cosine"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunk(self, processor):
|
||||
"""Test storing document embeddings with empty chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no upsert was called for empty chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_none_chunk(self, processor):
|
||||
"""Test storing document embeddings with None chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=None,
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no upsert was called for None chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_decoded_chunk(self, processor):
|
||||
"""Test storing document embeddings with chunk that decodes to empty string"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"", # Empty bytes
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no upsert was called for empty decoded chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test storing document embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Document with mixed dimensions",
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index_2d = MagicMock()
|
||||
mock_index_4d = MagicMock()
|
||||
mock_index_3d = MagicMock()
|
||||
|
||||
def mock_index_side_effect(name):
|
||||
if name.endswith("-2"):
|
||||
return mock_index_2d
|
||||
elif name.endswith("-4"):
|
||||
return mock_index_4d
|
||||
elif name.endswith("-3"):
|
||||
return mock_index_3d
|
||||
|
||||
processor.pinecone.Index.side_effect = mock_index_side_effect
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify different indexes were used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 3
|
||||
mock_index_2d.upsert.assert_called_once()
|
||||
mock_index_4d.upsert.assert_called_once()
|
||||
mock_index_3d.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunks_list(self, processor):
|
||||
"""Test storing document embeddings with empty chunks list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.chunks = []
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no operations were performed
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
|
||||
"""Test storing document embeddings for chunk with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Document with no vectors",
|
||||
vectors=[]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no upsert was called (no vectors to insert)
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation_failure(self, processor):
|
||||
"""Test handling of index creation failure"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist and creation fails
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index creation failed"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation_timeout(self, processor):
|
||||
"""Test handling of index creation timeout"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist and never becomes ready
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": False}
|
||||
|
||||
with patch('time.sleep'): # Speed up the test
|
||||
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_unicode_content(self, processor):
|
||||
"""Test storing document embeddings with Unicode content"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify Unicode content was properly decoded and stored
|
||||
call_args = mock_index.upsert.call_args
|
||||
stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
|
||||
assert stored_doc == "Document with Unicode: éñ中文🚀"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_large_chunks(self, processor):
|
||||
"""Test storing document embeddings with large document chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a large document chunk
|
||||
large_content = "A" * 10000 # 10KB of content
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=large_content.encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify large content was stored
|
||||
call_args = mock_index.upsert.call_args
|
||||
stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
|
||||
assert stored_doc == large_content
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'api_key')
|
||||
assert args.api_key == 'not-specified' # Default value when no env var
|
||||
assert hasattr(args, 'url')
|
||||
assert args.url is None
|
||||
assert hasattr(args, 'cloud')
|
||||
assert args.cloud == 'aws'
|
||||
assert hasattr(args, 'region')
|
||||
assert args.region == 'us-east-1'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--api-key', 'custom-api-key',
|
||||
'--url', 'https://custom-host.pinecone.io',
|
||||
'--cloud', 'gcp',
|
||||
'--region', 'us-west1'
|
||||
])
|
||||
|
||||
assert args.api_key == 'custom-api-key'
|
||||
assert args.url == 'https://custom-host.pinecone.io'
|
||||
assert args.cloud == 'gcp'
|
||||
assert args.region == 'us-west1'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args([
|
||||
'-a', 'short-api-key',
|
||||
'-u', 'https://short-host.pinecone.io'
|
||||
])
|
||||
|
||||
assert args.api_key == 'short-api-key'
|
||||
assert args.url == 'https://short-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.doc_embeddings.pinecone.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nAccepts document chunks/vector pairs and writes them to a Pinecone store.\n"
|
||||
)
|
||||
569
tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py
Normal file
569
tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py
Normal file
|
|
@ -0,0 +1,569 @@
|
|||
"""
|
||||
Unit tests for trustgraph.storage.doc_embeddings.qdrant.write
|
||||
Testing document embeddings storage functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||
|
||||
|
||||
class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant document embeddings storage functionality"""
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with basic message"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with chunks and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test document chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked
|
||||
expected_collection = 'd_test_user_test_collection_3'
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == expected_collection
|
||||
assert len(upsert_call_args[1]['points']) == 1
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload['doc'] == 'test document chunk'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_multiple_chunks(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with multiple chunks"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple chunks
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first document chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second document chunk'
|
||||
mock_chunk2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per chunk)
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
# Verify both chunks were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
# First chunk
|
||||
first_call = upsert_calls[0]
|
||||
first_point = first_call[1]['points'][0]
|
||||
assert first_point.vector == [0.1, 0.2]
|
||||
assert first_point.payload['doc'] == 'first document chunk'
|
||||
|
||||
# Second chunk
|
||||
second_call = upsert_calls[1]
|
||||
second_point = second_call[1]['points'][0]
|
||||
assert second_point.vector == [0.3, 0.4]
|
||||
assert second_point.payload['doc'] == 'second document chunk'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with multiple vectors per chunk"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with chunk having multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per vector)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
|
||||
# Verify all vectors were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
expected_vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
for i, call in enumerate(upsert_calls):
|
||||
point = call[1]['points'][0]
|
||||
assert point.vector == expected_vectors[i]
|
||||
assert point.payload['doc'] == 'multi-vector document chunk'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_empty_chunk(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test storing document embeddings skips empty chunks"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with empty chunk
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_chunk_empty = MagicMock()
|
||||
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
|
||||
mock_chunk_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk_empty]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty chunks
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation when it doesn't exist"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_new_user_new_collection_5'
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_collection
|
||||
|
||||
# Verify upsert was still called after collection creation
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation handles exceptions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'error_user'
|
||||
mock_message.metadata.collection = 'error_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_caching_behavior(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection caching with last_collection"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create first mock message
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'cache_user'
|
||||
mock_message1.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_message1.chunks = [mock_chunk1]
|
||||
|
||||
# First call
|
||||
await processor.store_document_embeddings(mock_message1)
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
# Create second mock message with same dimensions
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'cache_user'
|
||||
mock_message2.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second chunk'
|
||||
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
|
||||
|
||||
mock_message2.chunks = [mock_chunk2]
|
||||
|
||||
# Act - Second call with same collection
|
||||
await processor.store_document_embeddings(mock_message2)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection_3'
|
||||
assert processor.last_collection == expected_collection
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
# But upsert should still be called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_different_dimensions_different_collections(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that different vector dimensions create different collections"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'dim_user'
|
||||
mock_message.metadata.collection = 'dim_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'dimension test chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2], # 2 dimensions
|
||||
[0.3, 0.4, 0.5] # 3 dimensions
|
||||
]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should check existence of both collections
|
||||
expected_collections = ['d_dim_user_dim_collection_2', 'd_dim_user_dim_collection_3']
|
||||
actual_calls = [call.args[0] for call in mock_qdrant_instance.collection_exists.call_args_list]
|
||||
assert actual_calls == expected_collections
|
||||
|
||||
# Should upsert to both collections
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
assert upsert_calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
|
||||
assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_utf8_decoding_handling(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test proper UTF-8 decoding of chunk text"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with UTF-8 encoded text
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'utf8_user'
|
||||
mock_message.metadata.collection = 'utf8_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'UTF-8 text with special chars: café, naïve, résumé'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify chunk.decode was called with 'utf-8'
|
||||
mock_chunk.chunk.decode.assert_called_with('utf-8')
|
||||
|
||||
# Verify the decoded text was stored in payload
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_chunk_decode_exception_handling(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test handling of chunk decode exceptions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with decode error
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'decode_user'
|
||||
mock_message.metadata.collection = 'decode_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.side_effect = UnicodeDecodeError('utf-8', b'', 0, 1, 'invalid start byte')
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(UnicodeDecodeError):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
354
tests/unit/test_storage/test_graph_embeddings_milvus_storage.py
Normal file
354
tests/unit/test_storage/test_graph_embeddings_milvus_storage.py
Normal file
|
|
@ -0,0 +1,354 @@
|
|||
"""
|
||||
Tests for Milvus graph embeddings storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.graph_embeddings.milvus.write import Processor
|
||||
from trustgraph.schema import Value, EntityEmbeddings
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||
"""Test cases for Milvus graph embeddings storage processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test entities with embeddings
|
||||
entity1 = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity1', is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
entity2 = EntityEmbeddings(
|
||||
entity=Value(value='literal entity', is_uri=False),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.entities = [entity1, entity2]
|
||||
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors') as mock_entity_vectors:
|
||||
mock_vecstore = MagicMock()
|
||||
mock_entity_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-milvus-ge-storage',
|
||||
store_uri='http://localhost:19530'
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
|
||||
def test_processor_initialization_with_defaults(self, mock_entity_vectors):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_entity_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_entity_vectors.assert_called_once_with('http://localhost:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
|
||||
def test_processor_initialization_with_custom_params(self, mock_entity_vectors):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_entity_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
store_uri='http://custom-milvus:19530'
|
||||
)
|
||||
|
||||
mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_single_entity(self, processor):
|
||||
"""Test storing graph embeddings for a single entity"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify insert was called for each vector
|
||||
expected_calls = [
|
||||
([0.1, 0.2, 0.3], 'http://example.com/entity'),
|
||||
([0.4, 0.5, 0.6], 'http://example.com/entity'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_entity
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
|
||||
"""Test storing graph embeddings for multiple entities"""
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Verify insert was called for each vector of each entity
|
||||
expected_calls = [
|
||||
# Entity 1 vectors
|
||||
([0.1, 0.2, 0.3], 'http://example.com/entity1'),
|
||||
([0.4, 0.5, 0.6], 'http://example.com/entity1'),
|
||||
# Entity 2 vectors
|
||||
([0.7, 0.8, 0.9], 'literal entity'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_entity
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, processor):
|
||||
"""Test storing graph embeddings with empty entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='', is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no insert was called for empty entity
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_none_entity_value(self, processor):
|
||||
"""Test storing graph embeddings with None entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no insert was called for None entity
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_mixed_valid_invalid_entities(self, processor):
|
||||
"""Test storing graph embeddings with mix of valid and invalid entities"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/valid', is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
empty_entity = EntityEmbeddings(
|
||||
entity=Value(value='', is_uri=False),
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
none_entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.entities = [valid_entity, empty_entity, none_entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify only valid entity was inserted
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], 'http://example.com/valid'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_empty_entities_list(self, processor):
|
||||
"""Test storing graph embeddings with empty entities list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.entities = []
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no insert was called
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
|
||||
"""Test storing graph embeddings for entity with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
vectors=[]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no insert was called (no vectors to insert)
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test storing graph embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify all vectors were inserted regardless of dimension
|
||||
expected_calls = [
|
||||
([0.1, 0.2], 'http://example.com/entity'),
|
||||
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'),
|
||||
([0.7, 0.8, 0.9], 'http://example.com/entity'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_entity
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_uri_and_literal_entities(self, processor):
|
||||
"""Test storing graph embeddings for both URI and literal entities"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
uri_entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/uri_entity', is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
literal_entity = EntityEmbeddings(
|
||||
entity=Value(value='literal entity text', is_uri=False),
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.entities = [uri_entity, literal_entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify both entities were inserted
|
||||
expected_calls = [
|
||||
([0.1, 0.2, 0.3], 'http://example.com/uri_entity'),
|
||||
([0.4, 0.5, 0.6], 'literal entity text'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_entity
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'store_uri')
|
||||
assert args.store_uri == 'http://localhost:19530'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--store-uri', 'http://custom-milvus:19530'
|
||||
])
|
||||
|
||||
assert args.store_uri == 'http://custom-milvus:19530'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
|
||||
|
||||
assert args.store_uri == 'http://short-milvus:19530'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.milvus.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.graph_embeddings.milvus.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
|
||||
)
|
||||
|
|
@ -0,0 +1,460 @@
|
|||
"""
|
||||
Tests for Pinecone graph embeddings storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import uuid
|
||||
|
||||
from trustgraph.storage.graph_embeddings.pinecone.write import Processor
|
||||
from trustgraph.schema import EntityEmbeddings, Value
|
||||
|
||||
|
||||
class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||
"""Test cases for Pinecone graph embeddings storage processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test entity embeddings
|
||||
entity1 = EntityEmbeddings(
|
||||
entity=Value(value="http://example.org/entity1", is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
entity2 = EntityEmbeddings(
|
||||
entity=Value(value="entity2", is_uri=False),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.entities = [entity1, entity2]
|
||||
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone') as mock_pinecone_class:
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-pinecone-ge-storage',
|
||||
api_key='test-api-key'
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
|
||||
@patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'env-api-key')
|
||||
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
|
||||
"""Test processor initialization with default parameters"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
|
||||
assert processor.pinecone == mock_pinecone
|
||||
assert processor.api_key == 'env-api-key'
|
||||
assert processor.cloud == 'aws'
|
||||
assert processor.region == 'us-east-1'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
|
||||
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
api_key='custom-api-key',
|
||||
cloud='gcp',
|
||||
region='us-west1'
|
||||
)
|
||||
|
||||
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
|
||||
assert processor.api_key == 'custom-api-key'
|
||||
assert processor.cloud == 'gcp'
|
||||
assert processor.region == 'us-west1'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.pinecone.write.PineconeGRPC')
|
||||
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
|
||||
"""Test processor initialization with custom URL (GRPC mode)"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_grpc_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
api_key='test-api-key',
|
||||
url='https://custom-host.pinecone.io'
|
||||
)
|
||||
|
||||
mock_pinecone_grpc_class.assert_called_once_with(
|
||||
api_key='test-api-key',
|
||||
host='https://custom-host.pinecone.io'
|
||||
)
|
||||
assert processor.pinecone == mock_pinecone
|
||||
assert processor.url == 'https://custom-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'not-specified')
|
||||
def test_processor_initialization_missing_api_key(self):
|
||||
"""Test processor initialization fails with missing API key"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
|
||||
Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_single_entity(self, processor):
|
||||
"""Test storing graph embeddings for a single entity"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="http://example.org/entity1", is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index operations
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify index name and operations
|
||||
expected_index_name = "t-test_user-test_collection-3"
|
||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||
|
||||
# Verify upsert was called for each vector
|
||||
assert mock_index.upsert.call_count == 2
|
||||
|
||||
# Check first vector upsert
|
||||
first_call = mock_index.upsert.call_args_list[0]
|
||||
first_vectors = first_call[1]['vectors']
|
||||
assert len(first_vectors) == 1
|
||||
assert first_vectors[0]['id'] == 'id1'
|
||||
assert first_vectors[0]['values'] == [0.1, 0.2, 0.3]
|
||||
assert first_vectors[0]['metadata']['entity'] == "http://example.org/entity1"
|
||||
|
||||
# Check second vector upsert
|
||||
second_call = mock_index.upsert.call_args_list[1]
|
||||
second_vectors = second_call[1]['vectors']
|
||||
assert len(second_vectors) == 1
|
||||
assert second_vectors[0]['id'] == 'id2'
|
||||
assert second_vectors[0]['values'] == [0.4, 0.5, 0.6]
|
||||
assert second_vectors[0]['metadata']['entity'] == "http://example.org/entity1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
|
||||
"""Test storing graph embeddings for multiple entities"""
|
||||
# Mock index operations
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Verify upsert was called for each vector (3 total)
|
||||
assert mock_index.upsert.call_count == 3
|
||||
|
||||
# Verify entity values in metadata
|
||||
calls = mock_index.upsert.call_args_list
|
||||
assert calls[0][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1"
|
||||
assert calls[1][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1"
|
||||
assert calls[2][1]['vectors'][0]['metadata']['entity'] == "entity2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_creation(self, processor):
|
||||
"""Test automatic index creation when index doesn't exist"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index doesn't exist initially
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock index creation
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": True}
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify index creation was called
|
||||
expected_index_name = "t-test_user-test_collection-3"
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
create_call = processor.pinecone.create_index.call_args
|
||||
assert create_call[1]['name'] == expected_index_name
|
||||
assert create_call[1]['dimension'] == 3
|
||||
assert create_call[1]['metric'] == "cosine"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, processor):
|
||||
"""Test storing graph embeddings with empty entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no upsert was called for empty entity
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_none_entity_value(self, processor):
|
||||
"""Test storing graph embeddings with None entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no upsert was called for None entity
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test storing graph embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
mock_index_2d = MagicMock()
|
||||
mock_index_4d = MagicMock()
|
||||
mock_index_3d = MagicMock()
|
||||
|
||||
def mock_index_side_effect(name):
|
||||
if name.endswith("-2"):
|
||||
return mock_index_2d
|
||||
elif name.endswith("-4"):
|
||||
return mock_index_4d
|
||||
elif name.endswith("-3"):
|
||||
return mock_index_3d
|
||||
|
||||
processor.pinecone.Index.side_effect = mock_index_side_effect
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify different indexes were used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 3
|
||||
mock_index_2d.upsert.assert_called_once()
|
||||
mock_index_4d.upsert.assert_called_once()
|
||||
mock_index_3d.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_empty_entities_list(self, processor):
|
||||
"""Test storing graph embeddings with empty entities list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.entities = []
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no operations were performed
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
|
||||
"""Test storing graph embeddings for entity with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no upsert was called (no vectors to insert)
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_creation_failure(self, processor):
|
||||
"""Test handling of index creation failure"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index doesn't exist and creation fails
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index creation failed"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_creation_timeout(self, processor):
|
||||
"""Test handling of index creation timeout"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index doesn't exist and never becomes ready
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": False}
|
||||
|
||||
with patch('time.sleep'): # Speed up the test
|
||||
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added by parsing empty args
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'api_key')
|
||||
assert args.api_key == 'not-specified' # Default value when no env var
|
||||
assert hasattr(args, 'url')
|
||||
assert args.url is None
|
||||
assert hasattr(args, 'cloud')
|
||||
assert args.cloud == 'aws'
|
||||
assert hasattr(args, 'region')
|
||||
assert args.region == 'us-east-1'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--api-key', 'custom-api-key',
|
||||
'--url', 'https://custom-host.pinecone.io',
|
||||
'--cloud', 'gcp',
|
||||
'--region', 'us-west1'
|
||||
])
|
||||
|
||||
assert args.api_key == 'custom-api-key'
|
||||
assert args.url == 'https://custom-host.pinecone.io'
|
||||
assert args.cloud == 'gcp'
|
||||
assert args.region == 'us-west1'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args([
|
||||
'-a', 'short-api-key',
|
||||
'-u', 'https://short-host.pinecone.io'
|
||||
])
|
||||
|
||||
assert args.api_key == 'short-api-key'
|
||||
assert args.url == 'https://short-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.pinecone.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.graph_embeddings.pinecone.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nAccepts entity/vector pairs and writes them to a Pinecone store.\n"
|
||||
)
|
||||
428
tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py
Normal file
428
tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py
Normal file
|
|
@ -0,0 +1,428 @@
|
|||
"""
|
||||
Unit tests for trustgraph.storage.graph_embeddings.qdrant.write
|
||||
Starting small with a single test to verify basic functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant graph embeddings storage functionality"""
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creates_new_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection creates a new collection when it doesn't exist"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=512, user='test_user', collection='test_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_test_user_test_collection_512'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_name
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with basic message"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid-123'
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with entities and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'test_entity'
|
||||
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked
|
||||
expected_collection = 't_test_user_test_collection_3'
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == expected_collection
|
||||
assert len(upsert_call_args[1]['points']) == 1
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload['entity'] == 'test_entity'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_uses_existing_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection uses existing collection without creating new one"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=256, user='existing_user', collection='existing_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_existing_user_existing_collection_256'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
# Verify collection existence check was performed
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
# Verify create_collection was NOT called
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_caches_last_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection skips checks when using same collection"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# First call
|
||||
collection_name1 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
# Act - Second call with same parameters
|
||||
collection_name2 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_cache_user_cache_collection_128'
|
||||
assert collection_name1 == expected_name
|
||||
assert collection_name2 == expected_name
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection handles collection creation exceptions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
processor.get_collection(dim=512, user='error_user', collection='error_collection')
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_multiple_entities(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with multiple entities"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple entities
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.value = 'entity_one'
|
||||
mock_entity1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.value = 'entity_two'
|
||||
mock_entity2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per entity)
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
# Verify both entities were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
# First entity
|
||||
first_call = upsert_calls[0]
|
||||
first_point = first_call[1]['points'][0]
|
||||
assert first_point.vector == [0.1, 0.2]
|
||||
assert first_point.payload['entity'] == 'entity_one'
|
||||
|
||||
# Second entity
|
||||
second_call = upsert_calls[1]
|
||||
second_point = second_call[1]['points'][0]
|
||||
assert second_point.vector == [0.3, 0.4]
|
||||
assert second_point.payload['entity'] == 'entity_two'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with multiple vectors per entity"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with entity having multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'multi_vector_entity'
|
||||
mock_entity.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per vector)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
|
||||
# Verify all vectors were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
expected_vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
for i, call in enumerate(upsert_calls):
|
||||
point = call[1]['points'][0]
|
||||
assert point.vector == expected_vectors[i]
|
||||
assert point.payload['entity'] == 'multi_vector_entity'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test storing graph embeddings skips empty entity values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with empty entity value
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity_empty = MagicMock()
|
||||
mock_entity_empty.entity.value = "" # Empty string
|
||||
mock_entity_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_entity_none = MagicMock()
|
||||
mock_entity_none.entity.value = None # None value
|
||||
mock_entity_none.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty entities
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
328
tests/unit/test_storage/test_objects_cassandra_storage.py
Normal file
328
tests/unit/test_storage/test_objects_cassandra_storage.py
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
"""
|
||||
Unit tests for Cassandra Object Storage Processor
|
||||
|
||||
Tests the business logic of the object storage processor including:
|
||||
- Schema configuration handling
|
||||
- Type conversions
|
||||
- Name sanitization
|
||||
- Table structure generation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
from trustgraph.storage.objects.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
class TestObjectsCassandraStorageLogic:
|
||||
"""Test business logic without FlowProcessor dependencies"""
|
||||
|
||||
def test_sanitize_name(self):
|
||||
"""Test name sanitization for Cassandra compatibility"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
||||
# Test various name patterns (back to original logic)
|
||||
assert processor.sanitize_name("simple_name") == "simple_name"
|
||||
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
|
||||
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
|
||||
assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
|
||||
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
|
||||
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
|
||||
|
||||
def test_get_cassandra_type(self):
|
||||
"""Test field type conversion to Cassandra types"""
|
||||
processor = MagicMock()
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
|
||||
# Basic type mappings
|
||||
assert processor.get_cassandra_type("string") == "text"
|
||||
assert processor.get_cassandra_type("boolean") == "boolean"
|
||||
assert processor.get_cassandra_type("timestamp") == "timestamp"
|
||||
assert processor.get_cassandra_type("uuid") == "uuid"
|
||||
|
||||
# Integer types with size hints
|
||||
assert processor.get_cassandra_type("integer", size=2) == "int"
|
||||
assert processor.get_cassandra_type("integer", size=8) == "bigint"
|
||||
|
||||
# Float types with size hints
|
||||
assert processor.get_cassandra_type("float", size=2) == "float"
|
||||
assert processor.get_cassandra_type("float", size=8) == "double"
|
||||
|
||||
# Unknown type defaults to text
|
||||
assert processor.get_cassandra_type("unknown_type") == "text"
|
||||
|
||||
def test_convert_value(self):
|
||||
"""Test value conversion for different field types"""
|
||||
processor = MagicMock()
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
|
||||
# Integer conversions
|
||||
assert processor.convert_value("123", "integer") == 123
|
||||
assert processor.convert_value(123.5, "integer") == 123
|
||||
assert processor.convert_value(None, "integer") is None
|
||||
|
||||
# Float conversions
|
||||
assert processor.convert_value("123.45", "float") == 123.45
|
||||
assert processor.convert_value(123, "float") == 123.0
|
||||
|
||||
# Boolean conversions
|
||||
assert processor.convert_value("true", "boolean") is True
|
||||
assert processor.convert_value("false", "boolean") is False
|
||||
assert processor.convert_value("1", "boolean") is True
|
||||
assert processor.convert_value("0", "boolean") is False
|
||||
assert processor.convert_value("yes", "boolean") is True
|
||||
assert processor.convert_value("no", "boolean") is False
|
||||
|
||||
# String conversions
|
||||
assert processor.convert_value(123, "string") == "123"
|
||||
assert processor.convert_value(True, "string") == "True"
|
||||
|
||||
def test_table_creation_cql_generation(self):
|
||||
"""Test CQL generation for table creation"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
# Create test schema
|
||||
schema = RowSchema(
|
||||
name="customer_records",
|
||||
description="Test customer schema",
|
||||
fields=[
|
||||
Field(
|
||||
name="customer_id",
|
||||
type="string",
|
||||
size=50,
|
||||
primary=True,
|
||||
required=True,
|
||||
indexed=False
|
||||
),
|
||||
Field(
|
||||
name="email",
|
||||
type="string",
|
||||
size=100,
|
||||
required=True,
|
||||
indexed=True
|
||||
),
|
||||
Field(
|
||||
name="age",
|
||||
type="integer",
|
||||
size=4,
|
||||
required=False,
|
||||
indexed=False
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "customer_records", schema)
|
||||
|
||||
# Verify keyspace was ensured (check that it was added to known_keyspaces)
|
||||
assert "test_user" in processor.known_keyspaces
|
||||
|
||||
# Check the CQL that was executed (first call should be table creation)
|
||||
all_calls = processor.session.execute.call_args_list
|
||||
table_creation_cql = all_calls[0][0][0] # First call
|
||||
|
||||
# Verify table structure (keyspace uses sanitize_name, table uses sanitize_table)
|
||||
assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql
|
||||
assert "collection text" in table_creation_cql
|
||||
assert "customer_id text" in table_creation_cql
|
||||
assert "email text" in table_creation_cql
|
||||
assert "age int" in table_creation_cql
|
||||
assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql
|
||||
|
||||
def test_table_creation_without_primary_key(self):
|
||||
"""Test table creation when no primary key is defined"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
# Create schema without primary key
|
||||
schema = RowSchema(
|
||||
name="events",
|
||||
description="Event log",
|
||||
fields=[
|
||||
Field(name="event_type", type="string", size=50),
|
||||
Field(name="timestamp", type="timestamp", size=0)
|
||||
]
|
||||
)
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "events", schema)
|
||||
|
||||
# Check the CQL includes synthetic_id (field names don't get o_ prefix)
|
||||
executed_cql = processor.session.execute.call_args[0][0]
|
||||
assert "synthetic_id uuid" in executed_cql
|
||||
assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configurations"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test configuration
|
||||
config = {
|
||||
"schema": {
|
||||
"customer_records": json.dumps({
|
||||
"name": "customer_records",
|
||||
"description": "Customer data",
|
||||
"fields": [
|
||||
{
|
||||
"name": "id",
|
||||
"type": "string",
|
||||
"primary_key": True,
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "balance",
|
||||
"type": "float",
|
||||
"size": 8
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
# Process configuration
|
||||
await processor.on_schema_config(config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer_records" in processor.schemas
|
||||
schema = processor.schemas["customer_records"]
|
||||
assert schema.name == "customer_records"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
# Check field properties
|
||||
id_field = schema.fields[0]
|
||||
assert id_field.name == "id"
|
||||
assert id_field.type == "string"
|
||||
assert id_field.primary is True
|
||||
# Note: Field.required always returns False due to Pulsar schema limitations
|
||||
# The actual required value is tracked during schema parsing
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_processing_logic(self):
|
||||
"""Test the logic for processing ExtractedObject"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="integer", size=4)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.ensure_table = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
# Create test object
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="test_schema",
|
||||
values={"id": "123", "value": "456"},
|
||||
confidence=0.9,
|
||||
source_span="test source"
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
# Process object
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify table was ensured
|
||||
processor.ensure_table.assert_called_once_with("test_user", "test_schema", processor.schemas["test_schema"])
|
||||
|
||||
# Verify insert was executed (keyspace normal, table with o_ prefix)
|
||||
processor.session.execute.assert_called_once()
|
||||
insert_cql = processor.session.execute.call_args[0][0]
|
||||
values = processor.session.execute.call_args[0][1]
|
||||
|
||||
assert "INSERT INTO test_user.o_test_schema" in insert_cql
|
||||
assert "collection" in insert_cql
|
||||
assert values[0] == "test_collection" # collection value
|
||||
assert values[1] == "123" # id value
|
||||
assert values[2] == 456 # converted integer value
|
||||
|
||||
def test_secondary_index_creation(self):
|
||||
"""Test that secondary indexes are created for indexed fields"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
# Create schema with indexed field
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
description="Product catalog",
|
||||
fields=[
|
||||
Field(name="product_id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=30, indexed=True),
|
||||
Field(name="price", type="float", size=8, indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "products", schema)
|
||||
|
||||
# Should have 3 calls: create table + 2 indexes
|
||||
assert processor.session.execute.call_count == 3
|
||||
|
||||
# Check index creation calls (table has o_ prefix, fields don't)
|
||||
calls = processor.session.execute.call_args_list
|
||||
index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]]
|
||||
assert len(index_calls) == 2
|
||||
assert any("o_products_category_idx" in call for call in index_calls)
|
||||
assert any("o_products_price_idx" in call for call in index_calls)
|
||||
373
tests/unit/test_storage/test_triples_cassandra_storage.py
Normal file
373
tests/unit/test_storage/test_triples_cassandra_storage.py
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
"""
|
||||
Tests for Cassandra triples storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.storage.triples.cassandra.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
|
||||
|
||||
class TestCassandraStorageProcessor:
|
||||
"""Test cases for Cassandra storage processor"""
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.graph_host == ['localhost']
|
||||
assert processor.username is None
|
||||
assert processor.password is None
|
||||
assert processor.table is None
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
id='custom-storage',
|
||||
graph_host='cassandra.example.com',
|
||||
graph_username='testuser',
|
||||
graph_password='testpass'
|
||||
)
|
||||
|
||||
assert processor.graph_host == ['cassandra.example.com']
|
||||
assert processor.username == 'testuser'
|
||||
assert processor.password == 'testpass'
|
||||
assert processor.table is None
|
||||
|
||||
def test_processor_initialization_with_partial_auth(self):
|
||||
"""Test processor initialization with only username (no password)"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_username='testuser'
|
||||
)
|
||||
|
||||
assert processor.username == 'testuser'
|
||||
assert processor.password is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_table_switching_with_auth(self, mock_trustgraph):
|
||||
"""Test table switching logic when authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_username='testuser',
|
||||
graph_password='testpass'
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify TrustGraph was called with auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='user1',
|
||||
table='collection1',
|
||||
username='testuser',
|
||||
password='testpass'
|
||||
)
|
||||
assert processor.table == ('user1', 'collection1')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_table_switching_without_auth(self, mock_trustgraph):
|
||||
"""Test table switching logic when no authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user2'
|
||||
mock_message.metadata.collection = 'collection2'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify TrustGraph was called without auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='user2',
|
||||
table='collection2'
|
||||
)
|
||||
assert processor.table == ('user2', 'collection2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_table_reuse_when_same(self, mock_trustgraph):
|
||||
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
# First call should create TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
# Second call with same table should reuse TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_triple_insertion(self, mock_trustgraph):
|
||||
"""Test that triples are properly inserted into Cassandra"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triples
|
||||
triple1 = MagicMock()
|
||||
triple1.s.value = 'subject1'
|
||||
triple1.p.value = 'predicate1'
|
||||
triple1.o.value = 'object1'
|
||||
|
||||
triple2 = MagicMock()
|
||||
triple2.s.value = 'subject2'
|
||||
triple2.p.value = 'predicate2'
|
||||
triple2.o.value = 'object2'
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify both triples were inserted
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
mock_tg_instance.insert.assert_any_call('subject1', 'predicate1', 'object1')
|
||||
mock_tg_instance.insert.assert_any_call('subject2', 'predicate2', 'object2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_triple_insertion_with_empty_list(self, mock_trustgraph):
|
||||
"""Test behavior when message has no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message with empty triples
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph):
|
||||
"""Test exception handling during TrustGraph creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once_with(parser)
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'localhost'
|
||||
assert hasattr(args, 'graph_username')
|
||||
assert args.graph_username is None
|
||||
assert hasattr(args, 'graph_password')
|
||||
assert args.graph_password is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph-host', 'cassandra.example.com',
|
||||
'--graph-username', 'testuser',
|
||||
'--graph-password', 'testpass'
|
||||
])
|
||||
|
||||
assert args.graph_host == 'cassandra.example.com'
|
||||
assert args.graph_username == 'testuser'
|
||||
assert args.graph_password == 'testpass'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'short.example.com'])
|
||||
|
||||
assert args.graph_host == 'short.example.com'
|
||||
|
||||
@patch('trustgraph.storage.triples.cassandra.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.triples.cassandra.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph):
|
||||
"""Test table switching when different tables are used in sequence"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# First message with table1
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'user1'
|
||||
mock_message1.metadata.collection = 'collection1'
|
||||
mock_message1.triples = []
|
||||
|
||||
await processor.store_triples(mock_message1)
|
||||
assert processor.table == ('user1', 'collection1')
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'user2'
|
||||
mock_message2.metadata.collection = 'collection2'
|
||||
mock_message2.triples = []
|
||||
|
||||
await processor.store_triples(mock_message2)
|
||||
assert processor.table == ('user2', 'collection2')
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
# Verify TrustGraph was created twice for different tables
|
||||
assert mock_trustgraph.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph):
|
||||
"""Test storing triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create triple with special characters
|
||||
triple = MagicMock()
|
||||
triple.s.value = 'subject with spaces & symbols'
|
||||
triple.p.value = 'predicate:with/colons'
|
||||
triple.o.value = 'object with "quotes" and unicode: ñáéíóú'
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
'subject with spaces & symbols',
|
||||
'predicate:with/colons',
|
||||
'object with "quotes" and unicode: ñáéíóú'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph):
|
||||
"""Test that table remains unchanged when TrustGraph creation fails"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Set an initial table
|
||||
processor.table = ('old_user', 'old_collection')
|
||||
|
||||
# Mock TrustGraph to raise exception
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
# TrustGraph should be set to None though
|
||||
assert processor.tg is None
|
||||
436
tests/unit/test_storage/test_triples_falkordb_storage.py
Normal file
436
tests/unit/test_storage/test_triples_falkordb_storage.py
Normal file
|
|
@ -0,0 +1,436 @@
|
|||
"""
|
||||
Tests for FalkorDB triples storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.triples.falkordb.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
|
||||
|
||||
class TestFalkorDBStorageProcessor:
|
||||
"""Test cases for FalkorDB storage processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a test triple
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.storage.triples.falkordb.write.FalkorDB') as mock_falkordb:
|
||||
mock_client = MagicMock()
|
||||
mock_graph = MagicMock()
|
||||
mock_falkordb.from_url.return_value = mock_client
|
||||
mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
return Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-falkordb-storage',
|
||||
graph_url='falkor://localhost:6379',
|
||||
database='test_db'
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
|
||||
def test_processor_initialization_with_defaults(self, mock_falkordb):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_graph = MagicMock()
|
||||
mock_falkordb.from_url.return_value = mock_client
|
||||
mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.db == 'falkordb'
|
||||
mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379')
|
||||
mock_client.select_graph.assert_called_once_with('falkordb')
|
||||
|
||||
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
|
||||
def test_processor_initialization_with_custom_params(self, mock_falkordb):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_graph = MagicMock()
|
||||
mock_falkordb.from_url.return_value = mock_client
|
||||
mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_url='falkor://custom:6379',
|
||||
database='custom_db'
|
||||
)
|
||||
|
||||
assert processor.db == 'custom_db'
|
||||
mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379')
|
||||
mock_client.select_graph.assert_called_once_with('custom_db')
|
||||
|
||||
def test_create_node(self, processor):
|
||||
"""Test node creation"""
|
||||
test_uri = 'http://example.com/node'
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 1
|
||||
mock_result.run_time_ms = 10
|
||||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_node(test_uri)
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
params={
|
||||
"uri": test_uri,
|
||||
},
|
||||
)
|
||||
|
||||
def test_create_literal(self, processor):
|
||||
"""Test literal creation"""
|
||||
test_value = 'test literal value'
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 1
|
||||
mock_result.run_time_ms = 10
|
||||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_literal(test_value)
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
params={
|
||||
"value": test_value,
|
||||
},
|
||||
)
|
||||
|
||||
def test_relate_node(self, processor):
|
||||
"""Test node-to-node relationship creation"""
|
||||
src_uri = 'http://example.com/src'
|
||||
pred_uri = 'http://example.com/pred'
|
||||
dest_uri = 'http://example.com/dest'
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 0
|
||||
mock_result.run_time_ms = 5
|
||||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.relate_node(src_uri, pred_uri, dest_uri)
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
params={
|
||||
"src": src_uri,
|
||||
"dest": dest_uri,
|
||||
"uri": pred_uri,
|
||||
},
|
||||
)
|
||||
|
||||
def test_relate_literal(self, processor):
|
||||
"""Test node-to-literal relationship creation"""
|
||||
src_uri = 'http://example.com/src'
|
||||
pred_uri = 'http://example.com/pred'
|
||||
literal_value = 'literal destination'
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 0
|
||||
mock_result.run_time_ms = 5
|
||||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.relate_literal(src_uri, pred_uri, literal_value)
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
params={
|
||||
"src": src_uri,
|
||||
"dest": literal_value,
|
||||
"uri": pred_uri,
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_uri_object(self, processor):
|
||||
"""Test storing triple with URI object"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='http://example.com/object', is_uri=True)
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 1
|
||||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
|
||||
# Create object node
|
||||
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/object"}}),
|
||||
# Create relationship
|
||||
(("MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate"}}),
|
||||
]
|
||||
|
||||
assert processor.io.query.call_count == 3
|
||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
||||
actual_call = processor.io.query.call_args_list[i]
|
||||
assert actual_call[0] == expected_args
|
||||
assert actual_call[1] == expected_kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_literal_object(self, processor, mock_message):
|
||||
"""Test storing triple with literal object"""
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 1
|
||||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
|
||||
# Create literal object
|
||||
(("MERGE (n:Literal {value: $value})",), {"params": {"value": "literal object"}}),
|
||||
# Create relationship
|
||||
(("MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate"}}),
|
||||
]
|
||||
|
||||
assert processor.io.query.call_count == 3
|
||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
||||
actual_call = processor.io.query.call_args_list[i]
|
||||
assert actual_call[0] == expected_args
|
||||
assert actual_call[1] == expected_kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_multiple_triples(self, processor):
|
||||
"""Test storing multiple triples"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
s=Value(value='http://example.com/subject1', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate1', is_uri=True),
|
||||
o=Value(value='literal object1', is_uri=False)
|
||||
)
|
||||
triple2 = Triple(
|
||||
s=Value(value='http://example.com/subject2', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate2', is_uri=True),
|
||||
o=Value(value='http://example.com/object2', is_uri=True)
|
||||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 1
|
||||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
||||
# Verify first triple operations
|
||||
first_triple_calls = processor.io.query.call_args_list[0:3]
|
||||
assert first_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject1"
|
||||
assert first_triple_calls[1][1]["params"]["value"] == "literal object1"
|
||||
assert first_triple_calls[2][1]["params"]["src"] == "http://example.com/subject1"
|
||||
|
||||
# Verify second triple operations
|
||||
second_triple_calls = processor.io.query.call_args_list[3:6]
|
||||
assert second_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject2"
|
||||
assert second_triple_calls[1][1]["params"]["uri"] == "http://example.com/object2"
|
||||
assert second_triple_calls[2][1]["params"]["src"] == "http://example.com/subject2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_empty_list(self, processor):
|
||||
"""Test storing empty triples list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.triples = []
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify no queries were made
|
||||
processor.io.query.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_mixed_objects(self, processor):
|
||||
"""Test storing triples with mixed URI and literal objects"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
s=Value(value='http://example.com/subject1', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate1', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
)
|
||||
triple2 = Triple(
|
||||
s=Value(value='http://example.com/subject2', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate2', is_uri=True),
|
||||
o=Value(value='http://example.com/object2', is_uri=True)
|
||||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 1
|
||||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
||||
# Verify first triple creates literal
|
||||
assert "Literal" in processor.io.query.call_args_list[1][0][0]
|
||||
assert processor.io.query.call_args_list[1][1]["params"]["value"] == "literal object"
|
||||
|
||||
# Verify second triple creates node
|
||||
assert "Node" in processor.io.query.call_args_list[4][0][0]
|
||||
assert processor.io.query.call_args_list[4][1]["params"]["uri"] == "http://example.com/object2"
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_url')
|
||||
assert args.graph_url == 'falkor://falkordb:6379'
|
||||
assert hasattr(args, 'database')
|
||||
assert args.database == 'falkordb'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph-url', 'falkor://custom:6379',
|
||||
'--database', 'custom_db'
|
||||
])
|
||||
|
||||
assert args.graph_url == 'falkor://custom:6379'
|
||||
assert args.database == 'custom_db'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'falkor://short:6379'])
|
||||
|
||||
assert args.graph_url == 'falkor://short:6379'
|
||||
|
||||
@patch('trustgraph.storage.triples.falkordb.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.triples.falkordb.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nGraph writer. Input is graph edge. Writes edges to FalkorDB graph.\n"
|
||||
)
|
||||
|
||||
def test_create_node_with_special_characters(self, processor):
|
||||
"""Test node creation with special characters in URI"""
|
||||
test_uri = 'http://example.com/node with spaces & symbols'
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 1
|
||||
mock_result.run_time_ms = 10
|
||||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_node(test_uri)
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
params={
|
||||
"uri": test_uri,
|
||||
},
|
||||
)
|
||||
|
||||
def test_create_literal_with_special_characters(self, processor):
|
||||
"""Test literal creation with special characters"""
|
||||
test_value = 'literal with "quotes" and \n newlines'
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes_created = 1
|
||||
mock_result.run_time_ms = 10
|
||||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_literal(test_value)
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
params={
|
||||
"value": test_value,
|
||||
},
|
||||
)
|
||||
441
tests/unit/test_storage/test_triples_memgraph_storage.py
Normal file
441
tests/unit/test_storage/test_triples_memgraph_storage.py
Normal file
|
|
@ -0,0 +1,441 @@
|
|||
"""
|
||||
Tests for Memgraph triples storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.triples.memgraph.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
|
||||
|
||||
class TestMemgraphStorageProcessor:
|
||||
"""Test cases for Memgraph storage processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a test triple
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') as mock_graph_db:
|
||||
mock_driver = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
return Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-memgraph-storage',
|
||||
graph_host='bolt://localhost:7687',
|
||||
username='test_user',
|
||||
password='test_pass',
|
||||
database='test_db'
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_processor_initialization_with_defaults(self, mock_graph_db):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.db == 'memgraph'
|
||||
mock_graph_db.driver.assert_called_once_with(
|
||||
'bolt://memgraph:7687',
|
||||
auth=('memgraph', 'password')
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_processor_initialization_with_custom_params(self, mock_graph_db):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_host='bolt://custom:7687',
|
||||
username='custom_user',
|
||||
password='custom_pass',
|
||||
database='custom_db'
|
||||
)
|
||||
|
||||
assert processor.db == 'custom_db'
|
||||
mock_graph_db.driver.assert_called_once_with(
|
||||
'bolt://custom:7687',
|
||||
auth=('custom_user', 'custom_pass')
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_create_indexes_success(self, mock_graph_db):
|
||||
"""Test successful index creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Verify index creation calls
|
||||
expected_calls = [
|
||||
"CREATE INDEX ON :Node",
|
||||
"CREATE INDEX ON :Node(uri)",
|
||||
"CREATE INDEX ON :Literal",
|
||||
"CREATE INDEX ON :Literal(value)"
|
||||
]
|
||||
|
||||
assert mock_session.run.call_count == len(expected_calls)
|
||||
for i, expected_call in enumerate(expected_calls):
|
||||
actual_call = mock_session.run.call_args_list[i][0][0]
|
||||
assert actual_call == expected_call
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_create_indexes_with_exceptions(self, mock_graph_db):
|
||||
"""Test index creation with exceptions (should be ignored)"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Make all index creation calls raise exceptions
|
||||
mock_session.run.side_effect = Exception("Index already exists")
|
||||
|
||||
# Should not raise an exception
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Verify all index creation calls were attempted
|
||||
assert mock_session.run.call_count == 4
|
||||
|
||||
def test_create_node(self, processor):
|
||||
"""Test node creation"""
|
||||
test_uri = 'http://example.com/node'
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
|
||||
processor.io.execute_query.return_value = mock_result
|
||||
|
||||
processor.create_node(test_uri)
|
||||
|
||||
processor.io.execute_query.assert_called_once_with(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
uri=test_uri,
|
||||
database_=processor.db
|
||||
)
|
||||
|
||||
def test_create_literal(self, processor):
|
||||
"""Test literal creation"""
|
||||
test_value = 'test literal value'
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
|
||||
processor.io.execute_query.return_value = mock_result
|
||||
|
||||
processor.create_literal(test_value)
|
||||
|
||||
processor.io.execute_query.assert_called_once_with(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
value=test_value,
|
||||
database_=processor.db
|
||||
)
|
||||
|
||||
def test_relate_node(self, processor):
|
||||
"""Test node-to-node relationship creation"""
|
||||
src_uri = 'http://example.com/src'
|
||||
pred_uri = 'http://example.com/pred'
|
||||
dest_uri = 'http://example.com/dest'
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 5
|
||||
mock_result.summary = mock_summary
|
||||
|
||||
processor.io.execute_query.return_value = mock_result
|
||||
|
||||
processor.relate_node(src_uri, pred_uri, dest_uri)
|
||||
|
||||
processor.io.execute_query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src=src_uri, dest=dest_uri, uri=pred_uri,
|
||||
database_=processor.db
|
||||
)
|
||||
|
||||
def test_relate_literal(self, processor):
|
||||
"""Test node-to-literal relationship creation"""
|
||||
src_uri = 'http://example.com/src'
|
||||
pred_uri = 'http://example.com/pred'
|
||||
literal_value = 'literal destination'
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 5
|
||||
mock_result.summary = mock_summary
|
||||
|
||||
processor.io.execute_query.return_value = mock_result
|
||||
|
||||
processor.relate_literal(src_uri, pred_uri, literal_value)
|
||||
|
||||
processor.io.execute_query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src=src_uri, dest=literal_value, uri=pred_uri,
|
||||
database_=processor.db
|
||||
)
|
||||
|
||||
def test_create_triple_with_uri_object(self, processor):
|
||||
"""Test triple creation with URI object"""
|
||||
mock_tx = MagicMock()
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='http://example.com/object', is_uri=True)
|
||||
)
|
||||
|
||||
processor.create_triple(mock_tx, triple)
|
||||
|
||||
# Verify transaction calls
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
|
||||
# Create object node
|
||||
("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/object'}),
|
||||
# Create relationship
|
||||
("MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
{'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate'})
|
||||
]
|
||||
|
||||
assert mock_tx.run.call_count == 3
|
||||
for i, (expected_query, expected_params) in enumerate(expected_calls):
|
||||
actual_call = mock_tx.run.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_query
|
||||
assert actual_call[1] == expected_params
|
||||
|
||||
def test_create_triple_with_literal_object(self, processor):
|
||||
"""Test triple creation with literal object"""
|
||||
mock_tx = MagicMock()
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
)
|
||||
|
||||
processor.create_triple(mock_tx, triple)
|
||||
|
||||
# Verify transaction calls
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
|
||||
# Create literal object
|
||||
("MERGE (n:Literal {value: $value})", {'value': 'literal object'}),
|
||||
# Create relationship
|
||||
("MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
{'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate'})
|
||||
]
|
||||
|
||||
assert mock_tx.run.call_count == 3
|
||||
for i, (expected_query, expected_params) in enumerate(expected_calls):
|
||||
actual_call = mock_tx.run.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_query
|
||||
assert actual_call[1] == expected_params
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_single_triple(self, processor, mock_message):
|
||||
"""Test storing a single triple"""
|
||||
mock_session = MagicMock()
|
||||
processor.io.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Reset the mock to clear the initialization call
|
||||
processor.io.session.reset_mock()
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify session was created with correct database
|
||||
processor.io.session.assert_called_once_with(database=processor.db)
|
||||
|
||||
# Verify execute_write was called once per triple
|
||||
mock_session.execute_write.assert_called_once()
|
||||
|
||||
# Verify the triple was passed to create_triple
|
||||
call_args = mock_session.execute_write.call_args
|
||||
assert call_args[0][0] == processor.create_triple
|
||||
assert call_args[0][1] == mock_message.triples[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_multiple_triples(self, processor):
|
||||
"""Test storing multiple triples"""
|
||||
mock_session = MagicMock()
|
||||
processor.io.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Reset the mock to clear the initialization call
|
||||
processor.io.session.reset_mock()
|
||||
|
||||
# Create message with multiple triples
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
s=Value(value='http://example.com/subject1', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate1', is_uri=True),
|
||||
o=Value(value='literal object1', is_uri=False)
|
||||
)
|
||||
triple2 = Triple(
|
||||
s=Value(value='http://example.com/subject2', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate2', is_uri=True),
|
||||
o=Value(value='http://example.com/object2', is_uri=True)
|
||||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify session was called twice (once per triple)
|
||||
assert processor.io.session.call_count == 2
|
||||
|
||||
# Verify execute_write was called once per triple
|
||||
assert mock_session.execute_write.call_count == 2
|
||||
|
||||
# Verify each triple was processed
|
||||
call_args_list = mock_session.execute_write.call_args_list
|
||||
assert call_args_list[0][0][1] == triple1
|
||||
assert call_args_list[1][0][1] == triple2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_empty_list(self, processor):
|
||||
"""Test storing empty triples list"""
|
||||
mock_session = MagicMock()
|
||||
processor.io.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Reset the mock to clear the initialization call
|
||||
processor.io.session.reset_mock()
|
||||
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.triples = []
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify no session calls were made (no triples to process)
|
||||
processor.io.session.assert_not_called()
|
||||
|
||||
# Verify no execute_write calls were made
|
||||
mock_session.execute_write.assert_not_called()
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'bolt://memgraph:7687'
|
||||
assert hasattr(args, 'username')
|
||||
assert args.username == 'memgraph'
|
||||
assert hasattr(args, 'password')
|
||||
assert args.password == 'password'
|
||||
assert hasattr(args, 'database')
|
||||
assert args.database == 'memgraph'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph-host', 'bolt://custom:7687',
|
||||
'--username', 'custom_user',
|
||||
'--password', 'custom_pass',
|
||||
'--database', 'custom_db'
|
||||
])
|
||||
|
||||
assert args.graph_host == 'bolt://custom:7687'
|
||||
assert args.username == 'custom_user'
|
||||
assert args.password == 'custom_pass'
|
||||
assert args.database == 'custom_db'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'bolt://short:7687'])
|
||||
|
||||
assert args.graph_host == 'bolt://short:7687'
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.triples.memgraph.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nGraph writer. Input is graph edge. Writes edges to Memgraph.\n"
|
||||
)
|
||||
548
tests/unit/test_storage/test_triples_neo4j_storage.py
Normal file
548
tests/unit/test_storage/test_triples_neo4j_storage.py
Normal file
|
|
@ -0,0 +1,548 @@
|
|||
"""
|
||||
Tests for Neo4j triples storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.storage.triples.neo4j.write import Processor
|
||||
|
||||
|
||||
class TestNeo4jStorageProcessor:
|
||||
"""Test cases for Neo4j storage processor"""
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_processor_initialization_with_defaults(self, mock_graph_db):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.db == 'neo4j'
|
||||
mock_graph_db.driver.assert_called_once_with(
|
||||
'bolt://neo4j:7687',
|
||||
auth=('neo4j', 'password')
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_processor_initialization_with_custom_params(self, mock_graph_db):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_host='bolt://custom:7687',
|
||||
username='testuser',
|
||||
password='testpass',
|
||||
database='testdb'
|
||||
)
|
||||
|
||||
assert processor.db == 'testdb'
|
||||
mock_graph_db.driver.assert_called_once_with(
|
||||
'bolt://custom:7687',
|
||||
auth=('testuser', 'testpass')
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_create_indexes_success(self, mock_graph_db):
|
||||
"""Test successful index creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Verify index creation queries were executed
|
||||
expected_calls = [
|
||||
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
|
||||
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
|
||||
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)"
|
||||
]
|
||||
|
||||
assert mock_session.run.call_count == 3
|
||||
for expected_query in expected_calls:
|
||||
mock_session.run.assert_any_call(expected_query)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_create_indexes_with_exceptions(self, mock_graph_db):
|
||||
"""Test index creation with exceptions (should be ignored)"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Make session.run raise exceptions
|
||||
mock_session.run.side_effect = Exception("Index already exists")
|
||||
|
||||
# Should not raise exception - they should be caught and ignored
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Should have tried to create all 3 indexes despite exceptions
|
||||
assert mock_session.run.call_count == 3
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_create_node(self, mock_graph_db):
|
||||
"""Test node creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test create_node
|
||||
processor.create_node("http://example.com/node")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
uri="http://example.com/node",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_create_literal(self, mock_graph_db):
|
||||
"""Test literal creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test create_literal
|
||||
processor.create_literal("literal value")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
value="literal value",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_relate_node(self, mock_graph_db):
|
||||
"""Test node-to-node relationship creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test relate_node
|
||||
processor.relate_node(
|
||||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"http://example.com/object"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="http://example.com/object",
|
||||
uri="http://example.com/predicate",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_relate_literal(self, mock_graph_db):
|
||||
"""Test node-to-literal relationship creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test relate_literal
|
||||
processor.relate_literal(
|
||||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"literal value"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="literal value",
|
||||
uri="http://example.com/predicate",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_triples_with_uri_object(self, mock_graph_db):
|
||||
"""Test handling triples message with URI object"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triple with URI object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "http://example.com/object"
|
||||
triple.o.is_uri = True
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify create_node was called for subject and object
|
||||
# Verify relate_node was called
|
||||
expected_calls = [
|
||||
# Subject node creation
|
||||
(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
{"uri": "http://example.com/subject", "database_": "neo4j"}
|
||||
),
|
||||
# Object node creation
|
||||
(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
{"uri": "http://example.com/object", "database_": "neo4j"}
|
||||
),
|
||||
# Relationship creation
|
||||
(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
{
|
||||
"src": "http://example.com/subject",
|
||||
"dest": "http://example.com/object",
|
||||
"uri": "http://example.com/predicate",
|
||||
"database_": "neo4j"
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
assert mock_driver.execute_query.call_count == 3
|
||||
for expected_query, expected_params in expected_calls:
|
||||
mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_literal_object(self, mock_graph_db):
|
||||
"""Test handling triples message with literal object"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triple with literal object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "literal value"
|
||||
triple.o.is_uri = False
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify create_node was called for subject
|
||||
# Verify create_literal was called for object
|
||||
# Verify relate_literal was called
|
||||
expected_calls = [
|
||||
# Subject node creation
|
||||
(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
{"uri": "http://example.com/subject", "database_": "neo4j"}
|
||||
),
|
||||
# Literal creation
|
||||
(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
{"value": "literal value", "database_": "neo4j"}
|
||||
),
|
||||
# Relationship creation
|
||||
(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
{
|
||||
"src": "http://example.com/subject",
|
||||
"dest": "literal value",
|
||||
"uri": "http://example.com/predicate",
|
||||
"database_": "neo4j"
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
assert mock_driver.execute_query.call_count == 3
|
||||
for expected_query, expected_params in expected_calls:
|
||||
mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_multiple_triples(self, mock_graph_db):
|
||||
"""Test handling message with multiple triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triples
|
||||
triple1 = MagicMock()
|
||||
triple1.s.value = "http://example.com/subject1"
|
||||
triple1.p.value = "http://example.com/predicate1"
|
||||
triple1.o.value = "http://example.com/object1"
|
||||
triple1.o.is_uri = True
|
||||
|
||||
triple2 = MagicMock()
|
||||
triple2.s.value = "http://example.com/subject2"
|
||||
triple2.p.value = "http://example.com/predicate2"
|
||||
triple2.o.value = "literal value"
|
||||
triple2.o.is_uri = False
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should have processed both triples
|
||||
# Triple1: 2 nodes + 1 relationship = 3 calls
|
||||
# Triple2: 1 node + 1 literal + 1 relationship = 3 calls
|
||||
# Total: 6 calls
|
||||
assert mock_driver.execute_query.call_count == 6
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_empty_triples(self, mock_graph_db):
|
||||
"""Test handling message with no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message with empty triples
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should not have made any execute_query calls beyond index creation
|
||||
# Only index creation calls should have been made during initialization
|
||||
mock_driver.execute_query.assert_not_called()
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'bolt://neo4j:7687'
|
||||
assert hasattr(args, 'username')
|
||||
assert args.username == 'neo4j'
|
||||
assert hasattr(args, 'password')
|
||||
assert args.password == 'password'
|
||||
assert hasattr(args, 'database')
|
||||
assert args.database == 'neo4j'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph_host', 'bolt://custom:7687',
|
||||
'--username', 'testuser',
|
||||
'--password', 'testpass',
|
||||
'--database', 'testdb'
|
||||
])
|
||||
|
||||
assert args.graph_host == 'bolt://custom:7687'
|
||||
assert args.username == 'testuser'
|
||||
assert args.password == 'testpass'
|
||||
assert args.database == 'testdb'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'bolt://short:7687'])
|
||||
|
||||
assert args.graph_host == 'bolt://short:7687'
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.triples.neo4j.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nGraph writer. Input is graph edge. Writes edges to Neo4j graph.\n"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_special_characters(self, mock_graph_db):
|
||||
"""Test handling triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create triple with special characters
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject with spaces"
|
||||
triple.p.value = "http://example.com/predicate:with/symbols"
|
||||
triple.o.value = 'literal with "quotes" and unicode: ñáéíóú'
|
||||
triple.o.is_uri = False
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify the triple was processed with special characters preserved
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
uri="http://example.com/subject with spaces",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
value='literal with "quotes" and unicode: ñáéíóú',
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src="http://example.com/subject with spaces",
|
||||
dest='literal with "quotes" and unicode: ñáéíóú',
|
||||
uri="http://example.com/predicate:with/symbols",
|
||||
database_="neo4j"
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue