Release/v1.2 (#457)

* Bump setup.py versions for 1.1

* PoC MCP server (#419)

* Very initial MCP server PoC for TrustGraph

* Put service on port 8000

* Add MCP container and packages to buildout

* Update docs for API/CLI changes in 1.0 (#421)

* Update some API basics for the 0.23/1.0 API change

* Add MCP container push (#425)

* Add command args to the MCP server (#426)

* Host and port parameters

* Added websocket arg

* More docs

* MCP client support (#427)

- MCP client service
- Tool request/response schema
- API gateway support for mcp-tool
- Message translation for tool request & response
- Make mcp-tool use the configuration service for information
  about where the MCP services are.

* Feature/react call mcp (#428)

Key Features

  - MCP Tool Integration: Added core MCP tool support with ToolClientSpec and ToolClient classes
  - API Enhancement: New mcp_tool method for flow-specific tool invocation
  - CLI Tooling: New tg-invoke-mcp-tool command for testing MCP integration
  - React Agent Enhancement: Fixed and improved multi-tool invocation capabilities
  - Tool Management: Enhanced CLI for tool configuration and management

Changes

  - Added MCP tool invocation to API with flow-specific integration
  - Implemented ToolClientSpec and ToolClient for tool call handling
  - Updated agent-manager-react to invoke MCP tools with configurable types
  - Enhanced CLI with new commands and improved help text
  - Added comprehensive documentation for new CLI commands
  - Improved tool configuration management

Testing

  - Added tg-invoke-mcp-tool CLI command for isolated MCP integration testing
  - Enhanced agent capability to invoke multiple tools simultaneously

* Test suite executed from CI pipeline (#433)

* Test strategy & test cases

* Unit tests

* Integration tests

* Extending test coverage (#434)

* Contract tests

* Testing embeddings

* Agent unit tests

* Knowledge pipeline tests

* Turn on contract tests

* Increase storage test coverage (#435)

* Fixing storage and adding tests

* PR pipeline only runs quick tests

* Empty configuration is returned as an empty list; previously it was absent from the response (#436)

* Update config util to take files as well as command-line text (#437)

* Updated CLI invocation and config model for tools and mcp (#438)

* Updated CLI invocation and config model for tools and mcp

* CLI anomalies

* Tweaked the MCP tool implementation for new model

* Update agent implementation to match the new model

* Fix agent tools, now all tested

* Fixed integration tests

* Fix MCP delete tool params

* Update Python deps to 1.2

* Update to enable knowledge extraction using the agent framework (#439)

* Implement KG extraction agent (kg-extract-agent)

* Using ReAct framework (agent-manager-react)
 
* The ReAct manager had an issue when emitting JSON, which conflicted with its own JSON messages, so the ReAct manager was refactored to use traditional, non-JSON ReAct messages.
 
* Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework.

* Migrate from setup.py to pyproject.toml (#440)

* Converted setup.py to pyproject.toml

* Modern packaging infrastructure as recommended by the Python packaging docs

* Install missing build deps (#441)

* Install missing build deps (#442)

* Implement logging strategy (#444)

* Logging strategy; converted all print() calls to logging invocations

* Fix/startup failure (#445)

* Fix logging startup problems

* Fix logging startup problems (#446)

* Fix logging startup problems (#447)

* Fixed Mistral OCR to use current API (#448)

* Fixed Mistral OCR to use current API

* Added PDF decoder tests

* Fix Mistral OCR ident to be standard pdf-decoder (#450)

* Fix Mistral OCR ident to be standard pdf-decoder

* Correct test

* Schema structure refactor (#451)

* Write schema refactor spec

* Implemented schema refactor spec

* Structure data mvp (#452)

* Structured data tech spec

* Architecture principles

* New schemas

* Updated schemas and specs

* Object extractor

* Add .coveragerc

* New tests

* Cassandra object storage

* Trying to get object extraction working; issues remain

* Validate librarian collection (#453)

* Fix token chunker, broken API invocation (#454)

* Fix token chunker, broken API invocation (#455)

* Knowledge load utility CLI (#456)

* Knowledge loader

* More tests
Commit 89be656990 (parent c85ba197be), GPG key ID: B5690EEEBB952194
Author: cybermaggedon, 2025-08-18 20:56:09 +01:00, committed by GitHub
509 changed files with 49632 additions and 5159 deletions


@@ -0,0 +1,162 @@
"""
Shared fixtures for storage tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
@pytest.fixture
def base_storage_config():
"""Base configuration for storage processors"""
return {
'taskgroup': AsyncMock(),
'id': 'test-storage-processor'
}
@pytest.fixture
def qdrant_storage_config(base_storage_config):
"""Configuration for Qdrant storage processors"""
return base_storage_config | {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key'
}
@pytest.fixture
def mock_qdrant_client():
"""Mock Qdrant client"""
mock_client = MagicMock()
mock_client.collection_exists.return_value = True
mock_client.create_collection.return_value = None
mock_client.upsert.return_value = None
return mock_client
@pytest.fixture
def mock_uuid():
"""Mock UUID generation"""
mock_uuid = MagicMock()
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
return mock_uuid
# Document embeddings fixtures
@pytest.fixture
def mock_document_embeddings_message():
"""Mock document embeddings message"""
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test document chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3]]
mock_message.chunks = [mock_chunk]
return mock_message
@pytest.fixture
def mock_document_embeddings_multiple_chunks():
"""Mock document embeddings message with multiple chunks"""
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_chunk1 = MagicMock()
mock_chunk1.chunk.decode.return_value = 'first document chunk'
mock_chunk1.vectors = [[0.1, 0.2]]
mock_chunk2 = MagicMock()
mock_chunk2.chunk.decode.return_value = 'second document chunk'
mock_chunk2.vectors = [[0.3, 0.4]]
mock_message.chunks = [mock_chunk1, mock_chunk2]
return mock_message
@pytest.fixture
def mock_document_embeddings_multiple_vectors():
"""Mock document embeddings message with multiple vectors per chunk"""
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
mock_chunk.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
mock_message.chunks = [mock_chunk]
return mock_message
@pytest.fixture
def mock_document_embeddings_empty_chunk():
"""Mock document embeddings message with empty chunk"""
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = "" # Empty string
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
return mock_message
# Graph embeddings fixtures
@pytest.fixture
def mock_graph_embeddings_message():
"""Mock graph embeddings message"""
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_entity = MagicMock()
mock_entity.entity.value = 'test_entity'
mock_entity.vectors = [[0.1, 0.2, 0.3]]
mock_message.entities = [mock_entity]
return mock_message
@pytest.fixture
def mock_graph_embeddings_multiple_entities():
"""Mock graph embeddings message with multiple entities"""
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_entity1 = MagicMock()
mock_entity1.entity.value = 'entity_one'
mock_entity1.vectors = [[0.1, 0.2]]
mock_entity2 = MagicMock()
mock_entity2.entity.value = 'entity_two'
mock_entity2.vectors = [[0.3, 0.4]]
mock_message.entities = [mock_entity1, mock_entity2]
return mock_message
@pytest.fixture
def mock_graph_embeddings_empty_entity():
"""Mock graph embeddings message with empty entity"""
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_entity = MagicMock()
mock_entity.entity.value = "" # Empty string
mock_entity.vectors = [[0.1, 0.2]]
mock_message.entities = [mock_entity]
return mock_message
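The fixtures above lean on the `dict` merge operator (`|`, Python 3.9+) to derive backend-specific configs from `base_storage_config`. A quick standalone illustration of the merge semantics these fixtures rely on (variable names here are illustrative):

```python
from unittest.mock import AsyncMock

base = {'taskgroup': AsyncMock(), 'id': 'test-storage-processor'}

# Right-hand operand's keys are added; shared keys take the right-hand value
qdrant = base | {'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key'}
override = base | {'id': 'qdrant-processor'}

# '|' builds a new dict each time, so the base fixture is never mutated
```

Because `|` copies rather than mutates, each test gets an independent config even when several fixtures extend the same base.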


@@ -0,0 +1,576 @@
"""
Standalone unit tests for Cassandra Storage Logic
Tests core Cassandra storage logic without requiring full package imports.
This focuses on testing the business logic that would be used by the
Cassandra object storage processor components.
"""
import pytest
import json
import re
from unittest.mock import Mock
from typing import Dict, Any, List
class MockField:
"""Mock implementation of Field for testing"""
def __init__(self, name: str, type: str, primary: bool = False,
required: bool = False, indexed: bool = False,
enum_values: List[str] = None, size: int = 0):
self.name = name
self.type = type
self.primary = primary
self.required = required
self.indexed = indexed
self.enum_values = enum_values or []
self.size = size
class MockRowSchema:
"""Mock implementation of RowSchema for testing"""
def __init__(self, name: str, description: str, fields: List[MockField]):
self.name = name
self.description = description
self.fields = fields
class MockCassandraStorageLogic:
"""Mock implementation of Cassandra storage logic for testing"""
def __init__(self):
self.known_keyspaces = set()
self.known_tables = {} # keyspace -> set of table names
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility (keyspaces)"""
# Replace non-alphanumeric characters with underscore
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
if safe_name and not safe_name[0].isalpha():
safe_name = 'o_' + safe_name
return safe_name.lower()
def sanitize_table(self, name: str) -> str:
"""Sanitize table names for Cassandra compatibility"""
# Replace non-alphanumeric characters with underscore
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Always prefix tables with o_
safe_name = 'o_' + safe_name
return safe_name.lower()
def get_cassandra_type(self, field_type: str, size: int = 0) -> str:
"""Convert schema field type to Cassandra type"""
# Handle None size
if size is None:
size = 0
type_mapping = {
"string": "text",
"integer": "bigint" if size > 4 else "int",
"float": "double" if size > 4 else "float",
"boolean": "boolean",
"timestamp": "timestamp",
"date": "date",
"time": "time",
"uuid": "uuid"
}
return type_mapping.get(field_type, "text")
def convert_value(self, value: Any, field_type: str) -> Any:
"""Convert value to appropriate type for Cassandra"""
if value is None:
return None
try:
if field_type == "integer":
return int(value)
elif field_type == "float":
return float(value)
elif field_type == "boolean":
if isinstance(value, str):
return value.lower() in ('true', '1', 'yes')
return bool(value)
elif field_type == "timestamp":
# Handle timestamp conversion if needed
return value
else:
return str(value)
except Exception:
# Fallback to string conversion
return str(value)
def generate_table_cql(self, keyspace: str, table_name: str, schema: MockRowSchema) -> str:
"""Generate CREATE TABLE CQL statement"""
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
# Build column definitions
columns = ["collection text"] # Collection is always part of table
primary_key_fields = []
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
cassandra_type = self.get_cassandra_type(field.type, field.size)
columns.append(f"{safe_field_name} {cassandra_type}")
if field.primary:
primary_key_fields.append(safe_field_name)
# Build primary key - collection is always first in partition key
if primary_key_fields:
primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))"
else:
# If no primary key defined, use collection and a synthetic id
columns.append("synthetic_id uuid")
primary_key = "PRIMARY KEY ((collection, synthetic_id))"
# Create table CQL
create_table_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} (
{', '.join(columns)},
{primary_key}
)
"""
return create_table_cql.strip()
def generate_index_cql(self, keyspace: str, table_name: str, schema: MockRowSchema) -> List[str]:
"""Generate CREATE INDEX CQL statements for indexed fields"""
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
index_statements = []
for field in schema.fields:
if field.indexed and not field.primary:
safe_field_name = self.sanitize_name(field.name)
index_name = f"{safe_table}_{safe_field_name}_idx"
create_index_cql = f"""
CREATE INDEX IF NOT EXISTS {index_name}
ON {safe_keyspace}.{safe_table} ({safe_field_name})
"""
index_statements.append(create_index_cql.strip())
return index_statements
def generate_insert_cql(self, keyspace: str, table_name: str, schema: MockRowSchema,
values: Dict[str, Any], collection: str) -> tuple[str, tuple]:
"""Generate INSERT CQL statement and values tuple"""
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
# Build column names and values
columns = ["collection"]
value_list = [collection]
placeholders = ["%s"]
# Check if we need a synthetic ID
has_primary_key = any(field.primary for field in schema.fields)
if not has_primary_key:
import uuid
columns.append("synthetic_id")
value_list.append(uuid.uuid4())
placeholders.append("%s")
# Process fields
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
raw_value = values.get(field.name)
# Convert value to appropriate type
converted_value = self.convert_value(raw_value, field.type)
columns.append(safe_field_name)
value_list.append(converted_value)
placeholders.append("%s")
# Build insert query
insert_cql = f"""
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
"""
return insert_cql.strip(), tuple(value_list)
def validate_object_for_storage(self, obj_values: Dict[str, Any], schema: MockRowSchema) -> Dict[str, str]:
"""Validate object values for storage, return errors if any"""
errors = {}
# Check for missing required fields
for field in schema.fields:
if field.required and field.name not in obj_values:
errors[field.name] = f"Required field '{field.name}' is missing"
# Check primary key fields are not None/empty
if field.primary and field.name in obj_values:
value = obj_values[field.name]
if value is None or str(value).strip() == "":
errors[field.name] = f"Primary key field '{field.name}' cannot be empty"
# Check enum constraints
if field.enum_values and field.name in obj_values:
value = obj_values[field.name]
if value and value not in field.enum_values:
errors[field.name] = f"Value '{value}' not in allowed enum values: {field.enum_values}"
return errors
class TestCassandraStorageLogic:
"""Test cases for Cassandra storage business logic"""
@pytest.fixture
def storage_logic(self):
return MockCassandraStorageLogic()
@pytest.fixture
def customer_schema(self):
return MockRowSchema(
name="customer_records",
description="Customer information",
fields=[
MockField(
name="customer_id",
type="string",
primary=True,
required=True,
indexed=True
),
MockField(
name="name",
type="string",
required=True
),
MockField(
name="email",
type="string",
required=True,
indexed=True
),
MockField(
name="age",
type="integer",
size=4
),
MockField(
name="status",
type="string",
indexed=True,
enum_values=["active", "inactive", "suspended"]
)
]
)
def test_sanitize_name_keyspace(self, storage_logic):
"""Test name sanitization for Cassandra keyspaces"""
# Test various name patterns
assert storage_logic.sanitize_name("simple_name") == "simple_name"
assert storage_logic.sanitize_name("Name-With-Dashes") == "name_with_dashes"
assert storage_logic.sanitize_name("name.with.dots") == "name_with_dots"
assert storage_logic.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
assert storage_logic.sanitize_name("name with spaces") == "name_with_spaces"
assert storage_logic.sanitize_name("special!@#$%^chars") == "special______chars"
def test_sanitize_table_name(self, storage_logic):
"""Test table name sanitization"""
# Tables always get o_ prefix
assert storage_logic.sanitize_table("simple_name") == "o_simple_name"
assert storage_logic.sanitize_table("Name-With-Dashes") == "o_name_with_dashes"
assert storage_logic.sanitize_table("name.with.dots") == "o_name_with_dots"
assert storage_logic.sanitize_table("123_starts_with_number") == "o_123_starts_with_number"
def test_get_cassandra_type(self, storage_logic):
"""Test field type conversion to Cassandra types"""
# Basic type mappings
assert storage_logic.get_cassandra_type("string") == "text"
assert storage_logic.get_cassandra_type("boolean") == "boolean"
assert storage_logic.get_cassandra_type("timestamp") == "timestamp"
assert storage_logic.get_cassandra_type("uuid") == "uuid"
# Integer types with size hints
assert storage_logic.get_cassandra_type("integer", size=2) == "int"
assert storage_logic.get_cassandra_type("integer", size=8) == "bigint"
# Float types with size hints
assert storage_logic.get_cassandra_type("float", size=2) == "float"
assert storage_logic.get_cassandra_type("float", size=8) == "double"
# Unknown type defaults to text
assert storage_logic.get_cassandra_type("unknown_type") == "text"
def test_convert_value(self, storage_logic):
"""Test value conversion for different field types"""
# Integer conversions
assert storage_logic.convert_value("123", "integer") == 123
assert storage_logic.convert_value(123.5, "integer") == 123
assert storage_logic.convert_value(None, "integer") is None
# Float conversions
assert storage_logic.convert_value("123.45", "float") == 123.45
assert storage_logic.convert_value(123, "float") == 123.0
# Boolean conversions
assert storage_logic.convert_value("true", "boolean") is True
assert storage_logic.convert_value("false", "boolean") is False
assert storage_logic.convert_value("1", "boolean") is True
assert storage_logic.convert_value("0", "boolean") is False
assert storage_logic.convert_value("yes", "boolean") is True
assert storage_logic.convert_value("no", "boolean") is False
# String conversions
assert storage_logic.convert_value(123, "string") == "123"
assert storage_logic.convert_value(True, "string") == "True"
def test_generate_table_cql(self, storage_logic, customer_schema):
"""Test CREATE TABLE CQL generation"""
# Act
cql = storage_logic.generate_table_cql("test_user", "customer_records", customer_schema)
# Assert
assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in cql
assert "collection text" in cql
assert "customer_id text" in cql
assert "name text" in cql
assert "email text" in cql
assert "age int" in cql
assert "status text" in cql
assert "PRIMARY KEY ((collection, customer_id))" in cql
def test_generate_table_cql_without_primary_key(self, storage_logic):
"""Test table creation when no primary key is defined"""
# Arrange
schema = MockRowSchema(
name="events",
description="Event log",
fields=[
MockField(name="event_type", type="string"),
MockField(name="timestamp", type="timestamp")
]
)
# Act
cql = storage_logic.generate_table_cql("test_user", "events", schema)
# Assert
assert "synthetic_id uuid" in cql
assert "PRIMARY KEY ((collection, synthetic_id))" in cql
def test_generate_index_cql(self, storage_logic, customer_schema):
"""Test CREATE INDEX CQL generation"""
# Act
index_statements = storage_logic.generate_index_cql("test_user", "customer_records", customer_schema)
# Assert
# Should create indexes for customer_id, email, and status (indexed fields)
# But not for customer_id since it's also primary
assert len(index_statements) == 2 # email and status
# Check index creation
index_texts = " ".join(index_statements)
assert "o_customer_records_email_idx" in index_texts
assert "o_customer_records_status_idx" in index_texts
assert "CREATE INDEX IF NOT EXISTS" in index_texts
assert "customer_id" not in index_texts # Primary keys don't get indexes
def test_generate_insert_cql(self, storage_logic, customer_schema):
"""Test INSERT CQL generation"""
# Arrange
values = {
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com",
"age": 30,
"status": "active"
}
collection = "test_collection"
# Act
insert_cql, value_tuple = storage_logic.generate_insert_cql(
"test_user", "customer_records", customer_schema, values, collection
)
# Assert
assert "INSERT INTO test_user.o_customer_records" in insert_cql
assert "collection" in insert_cql
assert "customer_id" in insert_cql
assert "VALUES" in insert_cql
assert "%s" in insert_cql
# Check values tuple
assert value_tuple[0] == "test_collection" # collection
assert "CUST001" in value_tuple # customer_id
assert "John Doe" in value_tuple # name
assert 30 in value_tuple # age (converted to int)
def test_generate_insert_cql_without_primary_key(self, storage_logic):
"""Test INSERT CQL generation for schema without primary key"""
# Arrange
schema = MockRowSchema(
name="events",
description="Event log",
fields=[MockField(name="event_type", type="string")]
)
values = {"event_type": "login"}
# Act
insert_cql, value_tuple = storage_logic.generate_insert_cql(
"test_user", "events", schema, values, "test_collection"
)
# Assert
assert "synthetic_id" in insert_cql
assert len(value_tuple) == 3 # collection, synthetic_id, event_type
# Check that synthetic_id is a UUID (has correct format)
import uuid
assert isinstance(value_tuple[1], uuid.UUID)
def test_validate_object_for_storage_success(self, storage_logic, customer_schema):
"""Test successful object validation for storage"""
# Arrange
valid_values = {
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com",
"age": 30,
"status": "active"
}
# Act
errors = storage_logic.validate_object_for_storage(valid_values, customer_schema)
# Assert
assert len(errors) == 0
def test_validate_object_missing_required_fields(self, storage_logic, customer_schema):
"""Test object validation with missing required fields"""
# Arrange
invalid_values = {
"customer_id": "CUST001",
# Missing required 'name' and 'email' fields
"status": "active"
}
# Act
errors = storage_logic.validate_object_for_storage(invalid_values, customer_schema)
# Assert
assert len(errors) == 2
assert "name" in errors
assert "email" in errors
assert "Required field" in errors["name"]
def test_validate_object_empty_primary_key(self, storage_logic, customer_schema):
"""Test object validation with empty primary key"""
# Arrange
invalid_values = {
"customer_id": "", # Empty primary key
"name": "John Doe",
"email": "john@example.com",
"status": "active"
}
# Act
errors = storage_logic.validate_object_for_storage(invalid_values, customer_schema)
# Assert
assert len(errors) == 1
assert "customer_id" in errors
assert "Primary key field" in errors["customer_id"]
assert "cannot be empty" in errors["customer_id"]
def test_validate_object_invalid_enum(self, storage_logic, customer_schema):
"""Test object validation with invalid enum value"""
# Arrange
invalid_values = {
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com",
"status": "invalid_status" # Not in enum
}
# Act
errors = storage_logic.validate_object_for_storage(invalid_values, customer_schema)
# Assert
assert len(errors) == 1
assert "status" in errors
assert "not in allowed enum values" in errors["status"]
def test_complex_schema_with_all_features(self, storage_logic):
"""Test complex schema with all field features"""
# Arrange
complex_schema = MockRowSchema(
name="complex_table",
description="Complex table with all features",
fields=[
MockField(name="id", type="uuid", primary=True, required=True),
MockField(name="name", type="string", required=True, indexed=True),
MockField(name="count", type="integer", size=8),
MockField(name="price", type="float", size=8),
MockField(name="active", type="boolean"),
MockField(name="created", type="timestamp"),
MockField(name="category", type="string", enum_values=["A", "B", "C"], indexed=True)
]
)
# Act - Generate table CQL
table_cql = storage_logic.generate_table_cql("complex_db", "complex_table", complex_schema)
# Act - Generate index CQL
index_statements = storage_logic.generate_index_cql("complex_db", "complex_table", complex_schema)
# Assert table creation
assert "complex_db.o_complex_table" in table_cql
assert "id uuid" in table_cql
assert "count bigint" in table_cql # size 8 -> bigint
assert "price double" in table_cql # size 8 -> double
assert "active boolean" in table_cql
assert "created timestamp" in table_cql
assert "PRIMARY KEY ((collection, id))" in table_cql
# Assert index creation (name and category are indexed, but not id since it's primary)
assert len(index_statements) == 2
index_text = " ".join(index_statements)
assert "name_idx" in index_text
assert "category_idx" in index_text
def test_storage_workflow_simulation(self, storage_logic, customer_schema):
"""Test complete storage workflow simulation"""
keyspace = "customer_db"
table_name = "customers"
collection = "import_batch_1"
# Step 1: Generate table creation
table_cql = storage_logic.generate_table_cql(keyspace, table_name, customer_schema)
assert "CREATE TABLE IF NOT EXISTS" in table_cql
# Step 2: Generate indexes
index_statements = storage_logic.generate_index_cql(keyspace, table_name, customer_schema)
assert len(index_statements) > 0
# Step 3: Validate and insert object
customer_data = {
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com",
"age": 35,
"status": "active"
}
# Validate
errors = storage_logic.validate_object_for_storage(customer_data, customer_schema)
assert len(errors) == 0
# Generate insert
insert_cql, values = storage_logic.generate_insert_cql(
keyspace, table_name, customer_schema, customer_data, collection
)
assert "customer_db.o_customers" in insert_cql
assert values[0] == collection
assert "CUST001" in values
assert "John Doe" in values


@@ -0,0 +1,387 @@
"""
Tests for Milvus document embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.doc_embeddings.milvus.write import Processor
from trustgraph.schema import ChunkEmbeddings
class TestMilvusDocEmbeddingsStorageProcessor:
"""Test cases for Milvus document embeddings storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
chunk1 = ChunkEmbeddings(
chunk=b"This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
chunk2 = ChunkEmbeddings(
chunk=b"This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]]
)
message.chunks = [chunk1, chunk2]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors') as mock_doc_vectors:
mock_vecstore = MagicMock()
mock_doc_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=MagicMock(),
id='test-milvus-de-storage',
store_uri='http://localhost:19530'
)
return processor
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
def test_processor_initialization_with_defaults(self, mock_doc_vectors):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_doc_vectors.return_value = mock_vecstore
processor = Processor(taskgroup=taskgroup_mock)
mock_doc_vectors.assert_called_once_with('http://localhost:19530')
assert processor.vecstore == mock_vecstore
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
def test_processor_initialization_with_custom_params(self, mock_doc_vectors):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_doc_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=taskgroup_mock,
store_uri='http://custom-milvus:19530'
)
mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530')
assert processor.vecstore == mock_vecstore
@pytest.mark.asyncio
async def test_store_document_embeddings_single_chunk(self, processor):
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify insert was called for each vector
expected_calls = [
([0.1, 0.2, 0.3], "Test document content"),
([0.4, 0.5, 0.6], "Test document content"),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_doc) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_doc
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
"""Test storing document embeddings for multiple chunks"""
await processor.store_document_embeddings(mock_message)
# Verify insert was called for each vector of each chunk
expected_calls = [
# Chunk 1 vectors
([0.1, 0.2, 0.3], "This is the first document chunk"),
([0.4, 0.5, 0.6], "This is the first document chunk"),
# Chunk 2 vectors
([0.7, 0.8, 0.9], "This is the second document chunk"),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_doc) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_doc
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunk(self, processor):
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called for empty chunk
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_none_chunk(self, processor):
"""Test storing document embeddings with None chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=None,
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called for None chunk
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor):
"""Test storing document embeddings with mix of valid and invalid chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_chunk = ChunkEmbeddings(
chunk=b"Valid document content",
vectors=[[0.1, 0.2, 0.3]]
)
empty_chunk = ChunkEmbeddings(
chunk=b"",
vectors=[[0.4, 0.5, 0.6]]
)
none_chunk = ChunkEmbeddings(
chunk=None,
vectors=[[0.7, 0.8, 0.9]]
)
message.chunks = [valid_chunk, empty_chunk, none_chunk]
await processor.store_document_embeddings(message)
# Verify only valid chunk was inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Valid document content"
)
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor):
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
await processor.store_document_embeddings(message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with no vectors",
vectors=[]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_different_vector_dimensions(self, processor):
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with mixed dimensions",
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify all vectors were inserted regardless of dimension
expected_calls = [
([0.1, 0.2], "Document with mixed dimensions"),
([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions"),
([0.7, 0.8, 0.9], "Document with mixed dimensions"),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_doc) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_doc
@pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor):
"""Test storing document embeddings with Unicode content"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify Unicode content was properly decoded and inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀"
)
@pytest.mark.asyncio
async def test_store_document_embeddings_large_chunks(self, processor):
"""Test storing document embeddings with large document chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a large document chunk
large_content = "A" * 10000 # 10KB of content
chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify large content was inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], large_content
)
@pytest.mark.asyncio
async def test_store_document_embeddings_whitespace_only_chunk(self, processor):
"""Test storing document embeddings with whitespace-only chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b" \n\t ",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify whitespace content was inserted (not filtered out)
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], " \n\t "
)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'store_uri')
assert args.store_uri == 'http://localhost:19530'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--store-uri', 'http://custom-milvus:19530'
])
assert args.store_uri == 'http://custom-milvus:19530'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
assert args.store_uri == 'http://short-milvus:19530'
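# The three add_args tests above pin down the expected flag shape. A minimal
# sketch of the registration they assume (hypothetical helper, not the actual
# Processor.add_args) would be:

```python
from argparse import ArgumentParser

def add_store_uri_arg(parser: ArgumentParser) -> None:
    # Hypothetical registration matching what the tests parse:
    # short form -t, long form --store-uri, default http://localhost:19530.
    parser.add_argument(
        "-t", "--store-uri",
        default="http://localhost:19530",
        help="Milvus store URI",
    )

parser = ArgumentParser()
add_store_uri_arg(parser)
print(parser.parse_args([]).store_uri)
print(parser.parse_args(["-t", "http://short-milvus:19530"]).store_uri)
```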
@patch('trustgraph.storage.doc_embeddings.milvus.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.doc_embeddings.milvus.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
)
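# Taken together, the Milvus tests above pin down a simple contract: falsy
# chunks (None or b"") are skipped, bytes are decoded as UTF-8, and one insert
# happens per vector. A self-contained sketch of that contract against a fake
# store (stand-in only, not the real Processor implementation):

```python
from dataclasses import dataclass, field

@dataclass
class FakeVecStore:
    """Minimal stand-in that records insert calls."""
    rows: list = field(default_factory=list)

    def insert(self, vec, doc):
        self.rows.append((vec, doc))

def store_document_embeddings(vecstore, chunks):
    # Contract the tests assert: skip falsy chunks, decode bytes as UTF-8,
    # insert one (vector, document) row per vector.
    for raw, vectors in chunks:
        if not raw:
            continue
        doc = raw.decode("utf-8")
        for vec in vectors:
            vecstore.insert(vec, doc)

store = FakeVecStore()
store_document_embeddings(store, [
    (b"hello", [[0.1, 0.2], [0.3, 0.4]]),
    (b"", [[0.5, 0.6]]),    # skipped: empty bytes
    (None, [[0.7, 0.8]]),   # skipped: None chunk
])
```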
"""
Tests for Pinecone document embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
import uuid
from trustgraph.storage.doc_embeddings.pinecone.write import Processor
from trustgraph.schema import ChunkEmbeddings
class TestPineconeDocEmbeddingsStorageProcessor:
"""Test cases for Pinecone document embeddings storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
chunk1 = ChunkEmbeddings(
chunk=b"This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
chunk2 = ChunkEmbeddings(
chunk=b"This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]]
)
message.chunks = [chunk1, chunk2]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone') as mock_pinecone_class:
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
processor = Processor(
taskgroup=MagicMock(),
id='test-pinecone-de-storage',
api_key='test-api-key'
)
return processor
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
@patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'env-api-key')
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
"""Test processor initialization with default parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
assert processor.pinecone == mock_pinecone
assert processor.api_key == 'env-api-key'
assert processor.cloud == 'aws'
assert processor.region == 'us-east-1'
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
"""Test processor initialization with custom parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='custom-api-key',
cloud='gcp',
region='us-west1'
)
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
assert processor.api_key == 'custom-api-key'
assert processor.cloud == 'gcp'
assert processor.region == 'us-west1'
@patch('trustgraph.storage.doc_embeddings.pinecone.write.PineconeGRPC')
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
"""Test processor initialization with custom URL (GRPC mode)"""
mock_pinecone = MagicMock()
mock_pinecone_grpc_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='test-api-key',
url='https://custom-host.pinecone.io'
)
mock_pinecone_grpc_class.assert_called_once_with(
api_key='test-api-key',
host='https://custom-host.pinecone.io'
)
assert processor.pinecone == mock_pinecone
assert processor.url == 'https://custom-host.pinecone.io'
@patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'not-specified')
def test_processor_initialization_missing_api_key(self):
"""Test processor initialization fails with missing API key"""
taskgroup_mock = MagicMock()
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
Processor(taskgroup=taskgroup_mock)
@pytest.mark.asyncio
async def test_store_document_embeddings_single_chunk(self, processor):
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.chunks = [chunk]
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_document_embeddings(message)
# Verify index name and operations
expected_index_name = "d-test_user-test_collection-3"
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
assert mock_index.upsert.call_count == 2
# Check first vector upsert
first_call = mock_index.upsert.call_args_list[0]
first_vectors = first_call[1]['vectors']
assert len(first_vectors) == 1
assert first_vectors[0]['id'] == 'id1'
assert first_vectors[0]['values'] == [0.1, 0.2, 0.3]
assert first_vectors[0]['metadata']['doc'] == "Test document content"
# Check second vector upsert
second_call = mock_index.upsert.call_args_list[1]
second_vectors = second_call[1]['vectors']
assert len(second_vectors) == 1
assert second_vectors[0]['id'] == 'id2'
assert second_vectors[0]['values'] == [0.4, 0.5, 0.6]
assert second_vectors[0]['metadata']['doc'] == "Test document content"
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
"""Test storing document embeddings for multiple chunks"""
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_document_embeddings(mock_message)
# Verify upsert was called for each vector (3 total)
assert mock_index.upsert.call_count == 3
# Verify document content in metadata
calls = mock_index.upsert.call_args_list
assert calls[0][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
assert calls[1][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk"
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation(self, processor):
"""Test automatic index creation when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist initially
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock index creation
processor.pinecone.describe_index.return_value.status = {"ready": True}
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
# Verify index creation was called
expected_index_name = "d-test_user-test_collection-3"
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
assert create_call[1]['dimension'] == 3
assert create_call[1]['metric'] == "cosine"
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunk(self, processor):
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no upsert was called for empty chunk
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_none_chunk(self, processor):
"""Test storing document embeddings with None chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=None,
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no upsert was called for None chunk
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_decoded_chunk(self, processor):
"""Test storing document embeddings with chunk that decodes to empty string"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"", # Empty bytes
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no upsert was called for empty decoded chunk
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_different_vector_dimensions(self, processor):
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with mixed dimensions",
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
)
message.chunks = [chunk]
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
mock_index_3d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
elif name.endswith("-3"):
return mock_index_3d
processor.pinecone.Index.side_effect = mock_index_side_effect
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_document_embeddings(message)
# Verify different indexes were used for different dimensions
assert processor.pinecone.Index.call_count == 3
mock_index_2d.upsert.assert_called_once()
mock_index_4d.upsert.assert_called_once()
mock_index_3d.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor):
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with no vectors",
vectors=[]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation_failure(self, processor):
"""Test handling of index creation failure"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist and creation fails
processor.pinecone.has_index.return_value = False
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
with pytest.raises(Exception, match="Index creation failed"):
await processor.store_document_embeddings(message)
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation_timeout(self, processor):
"""Test handling of index creation timeout"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist and never becomes ready
processor.pinecone.has_index.return_value = False
processor.pinecone.describe_index.return_value.status = {"ready": False}
with patch('time.sleep'): # Speed up the test
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
await processor.store_document_embeddings(message)
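# The timeout test above exercises a poll-until-ready loop. A hypothetical
# sketch of such a loop (attempt count and delay are assumptions, not the
# actual implementation) that raises the RuntimeError the test matches on:

```python
import time

def wait_for_index_ready(describe_index, name, attempts=10, delay=1.0):
    # Poll describe_index(name).status["ready"] up to `attempts` times,
    # sleeping `delay` seconds between polls; give up with a RuntimeError.
    for _ in range(attempts):
        if describe_index(name).status["ready"]:
            return
        time.sleep(delay)
    raise RuntimeError("Gave up waiting for index creation")
```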
@pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor):
"""Test storing document embeddings with Unicode content"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
# Verify Unicode content was properly decoded and stored
call_args = mock_index.upsert.call_args
stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
assert stored_doc == "Document with Unicode: éñ中文🚀"
@pytest.mark.asyncio
async def test_store_document_embeddings_large_chunks(self, processor):
"""Test storing document embeddings with large document chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a large document chunk
large_content = "A" * 10000 # 10KB of content
chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
# Verify large content was stored
call_args = mock_index.upsert.call_args
stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
assert stored_doc == large_content
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
args = parser.parse_args([])
assert hasattr(args, 'api_key')
assert args.api_key == 'not-specified' # Default value when no env var
assert hasattr(args, 'url')
assert args.url is None
assert hasattr(args, 'cloud')
assert args.cloud == 'aws'
assert hasattr(args, 'region')
assert args.region == 'us-east-1'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--api-key', 'custom-api-key',
'--url', 'https://custom-host.pinecone.io',
'--cloud', 'gcp',
'--region', 'us-west1'
])
assert args.api_key == 'custom-api-key'
assert args.url == 'https://custom-host.pinecone.io'
assert args.cloud == 'gcp'
assert args.region == 'us-west1'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args([
'-a', 'short-api-key',
'-u', 'https://short-host.pinecone.io'
])
assert args.api_key == 'short-api-key'
assert args.url == 'https://short-host.pinecone.io'
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.doc_embeddings.pinecone.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nAccepts document chunks/vector pairs and writes them to a Pinecone store.\n"
)
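# The Pinecone tests above repeatedly assert the index name
# "d-test_user-test_collection-3". A hypothetical helper capturing that
# naming scheme (one index per vector dimension):

```python
def pinecone_index_name(user: str, collection: str, dim: int) -> str:
    # Naming scheme the tests assert: "d-<user>-<collection>-<dim>",
    # so vectors of different dimensionality land in separate indexes.
    return f"d-{user}-{collection}-{dim}"

print(pinecone_index_name("test_user", "test_collection", 3))
```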
"""
Unit tests for trustgraph.storage.doc_embeddings.qdrant.write
Testing document embeddings storage functionality
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
"""Test Qdrant document embeddings storage functionality"""
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
"""Test basic Qdrant processor initialization"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify base class initialization was called
mock_base_init.assert_called_once()
# Verify QdrantClient was created with correct parameters
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
# Verify processor attributes
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
assert hasattr(processor, 'last_collection')
assert processor.last_collection is None
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
"""Test processor initialization with default values"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
# No store_uri or api_key provided - should use defaults
}
# Act
processor = Processor(**config)
# Assert
# Verify QdrantClient was created with default URI and None API key
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_store_document_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test storing document embeddings with basic message"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with chunks and vectors
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test document chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Verify collection existence was checked
expected_collection = 'd_test_user_test_collection_3'
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == expected_collection
assert len(upsert_call_args[1]['points']) == 1
point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3]
assert point.payload['doc'] == 'test document chunk'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_store_document_embeddings_multiple_chunks(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test storing document embeddings with multiple chunks"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with multiple chunks
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_chunk1 = MagicMock()
mock_chunk1.chunk.decode.return_value = 'first document chunk'
mock_chunk1.vectors = [[0.1, 0.2]]
mock_chunk2 = MagicMock()
mock_chunk2.chunk.decode.return_value = 'second document chunk'
mock_chunk2.vectors = [[0.3, 0.4]]
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should be called twice (once per chunk)
assert mock_qdrant_instance.upsert.call_count == 2
# Verify both chunks were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
# First chunk
first_call = upsert_calls[0]
first_point = first_call[1]['points'][0]
assert first_point.vector == [0.1, 0.2]
assert first_point.payload['doc'] == 'first document chunk'
# Second chunk
second_call = upsert_calls[1]
second_point = second_call[1]['points'][0]
assert second_point.vector == [0.3, 0.4]
assert second_point.payload['doc'] == 'second document chunk'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test storing document embeddings with multiple vectors per chunk"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with chunk having multiple vectors
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
mock_chunk.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should be called 3 times (once per vector)
assert mock_qdrant_instance.upsert.call_count == 3
# Verify all vectors were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
expected_vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
for i, call in enumerate(upsert_calls):
point = call[1]['points'][0]
assert point.vector == expected_vectors[i]
assert point.payload['doc'] == 'multi-vector document chunk'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_store_document_embeddings_empty_chunk(self, mock_base_init, mock_qdrant_client):
"""Test storing document embeddings skips empty chunks"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with empty chunk
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_chunk_empty = MagicMock()
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
mock_chunk_empty.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk_empty]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should not call upsert for empty chunks
mock_qdrant_instance.upsert.assert_not_called()
mock_qdrant_instance.collection_exists.assert_not_called()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
"""Test collection creation when it doesn't exist"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
expected_collection = 'd_new_user_new_collection_5'
# Verify collection existence check and creation
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
mock_qdrant_instance.create_collection.assert_called_once()
# Verify create_collection was called with correct parameters
create_call_args = mock_qdrant_instance.create_collection.call_args
assert create_call_args[1]['collection_name'] == expected_collection
# Verify upsert was still called after collection creation
mock_qdrant_instance.upsert.assert_called_once()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
"""Test collection creation handles exceptions"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'error_user'
mock_message.metadata.collection = 'error_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk'
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.store_document_embeddings(mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_caching_behavior(self, mock_base_init, mock_qdrant_client):
"""Test collection caching with last_collection"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create first mock message
mock_message1 = MagicMock()
mock_message1.metadata.user = 'cache_user'
mock_message1.metadata.collection = 'cache_collection'
mock_chunk1 = MagicMock()
mock_chunk1.chunk.decode.return_value = 'first chunk'
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
mock_message1.chunks = [mock_chunk1]
# First call
await processor.store_document_embeddings(mock_message1)
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
# Create second mock message with same dimensions
mock_message2 = MagicMock()
mock_message2.metadata.user = 'cache_user'
mock_message2.metadata.collection = 'cache_collection'
mock_chunk2 = MagicMock()
mock_chunk2.chunk.decode.return_value = 'second chunk'
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
mock_message2.chunks = [mock_chunk2]
# Act - Second call with same collection
await processor.store_document_embeddings(mock_message2)
# Assert
expected_collection = 'd_cache_user_cache_collection_3'
assert processor.last_collection == expected_collection
# Verify second call skipped existence check (cached)
mock_qdrant_instance.collection_exists.assert_not_called()
mock_qdrant_instance.create_collection.assert_not_called()
# But upsert should still be called
mock_qdrant_instance.upsert.assert_called_once()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_different_dimensions_different_collections(self, mock_base_init, mock_qdrant_client):
"""Test that different vector dimensions create different collections"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with different dimension vectors
mock_message = MagicMock()
mock_message.metadata.user = 'dim_user'
mock_message.metadata.collection = 'dim_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'dimension test chunk'
mock_chunk.vectors = [
[0.1, 0.2], # 2 dimensions
[0.3, 0.4, 0.5] # 3 dimensions
]
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should check existence of both collections
expected_collections = ['d_dim_user_dim_collection_2', 'd_dim_user_dim_collection_3']
actual_calls = [call.args[0] for call in mock_qdrant_instance.collection_exists.call_args_list]
assert actual_calls == expected_collections
# Should upsert to both collections
assert mock_qdrant_instance.upsert.call_count == 2
upsert_calls = mock_qdrant_instance.upsert.call_args_list
assert upsert_calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
"""Test that add_args() calls parent add_args method"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_client.return_value = MagicMock()
mock_parser = MagicMock()
# Act
with patch('trustgraph.base.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(mock_parser)
# Assert
mock_parent_add_args.assert_called_once_with(mock_parser)
# Verify processor-specific arguments were added
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_utf8_decoding_handling(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test proper UTF-8 decoding of chunk text"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with UTF-8 encoded text
mock_message = MagicMock()
mock_message.metadata.user = 'utf8_user'
mock_message.metadata.collection = 'utf8_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'UTF-8 text with special chars: café, naïve, résumé'
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Verify chunk.decode was called with 'utf-8'
mock_chunk.chunk.decode.assert_called_with('utf-8')
# Verify the decoded text was stored in payload
upsert_call_args = mock_qdrant_instance.upsert.call_args
point = upsert_call_args[1]['points'][0]
assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_chunk_decode_exception_handling(self, mock_base_init, mock_qdrant_client):
"""Test handling of chunk decode exceptions"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with decode error
mock_message = MagicMock()
mock_message.metadata.user = 'decode_user'
mock_message.metadata.collection = 'decode_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.side_effect = UnicodeDecodeError('utf-8', b'', 0, 1, 'invalid start byte')
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(UnicodeDecodeError):
await processor.store_document_embeddings(mock_message)
if __name__ == '__main__':
pytest.main([__file__])
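The Qdrant tests above pin down two behaviours: a per-dimension collection-naming scheme (`d_<user>_<collection>_<dim>`) and a `last_collection` cache that skips the existence check and creation when consecutive messages target the same collection. A minimal standalone sketch of that contract, with hypothetical names (`make_collection_name`, `CollectionCache`) standing in for the real processor internals in `trustgraph.storage.doc_embeddings.qdrant.write`:

```python
# Standalone sketch of the naming/caching contract exercised above.
# make_collection_name and CollectionCache are hypothetical names; the real
# logic lives in trustgraph.storage.doc_embeddings.qdrant.write.

def make_collection_name(user: str, collection: str, dim: int) -> str:
    # e.g. ('new_user', 'new_collection', 5) -> 'd_new_user_new_collection_5'
    return f"d_{user}_{collection}_{dim}"

class CollectionCache:
    """Skips collection_exists/create_collection when the collection repeats."""

    def __init__(self, client):
        self.client = client
        self.last_collection = None

    def ensure_collection(self, user: str, collection: str, dim: int) -> str:
        name = make_collection_name(user, collection, dim)
        if name == self.last_collection:
            # Cached: no existence check or creation, matching
            # test_collection_caching_behavior above.
            return name
        if not self.client.collection_exists(name):
            self.client.create_collection(collection_name=name)
        self.last_collection = name
        return name
```

Because the cache key includes the dimension, the same user/collection pair with differently sized vectors still lands in separate collections, which is what `test_different_dimensions_different_collections` asserts.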

"""
Tests for Milvus graph embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.graph_embeddings.milvus.write import Processor
from trustgraph.schema import Value, EntityEmbeddings
class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test cases for Milvus graph embeddings storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entities with embeddings
entity1 = EntityEmbeddings(
entity=Value(value='http://example.com/entity1', is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
entity2 = EntityEmbeddings(
entity=Value(value='literal entity', is_uri=False),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [entity1, entity2]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors') as mock_entity_vectors:
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=MagicMock(),
id='test-milvus-ge-storage',
store_uri='http://localhost:19530'
)
return processor
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
def test_processor_initialization_with_defaults(self, mock_entity_vectors):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(taskgroup=taskgroup_mock)
mock_entity_vectors.assert_called_once_with('http://localhost:19530')
assert processor.vecstore == mock_vecstore
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
def test_processor_initialization_with_custom_params(self, mock_entity_vectors):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=taskgroup_mock,
store_uri='http://custom-milvus:19530'
)
mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530')
assert processor.vecstore == mock_vecstore
@pytest.mark.asyncio
async def test_store_graph_embeddings_single_entity(self, processor):
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify insert was called for each vector
expected_calls = [
([0.1, 0.2, 0.3], 'http://example.com/entity'),
([0.4, 0.5, 0.6], 'http://example.com/entity'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
@pytest.mark.asyncio
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
"""Test storing graph embeddings for multiple entities"""
await processor.store_graph_embeddings(mock_message)
# Verify insert was called for each vector of each entity
expected_calls = [
# Entity 1 vectors
([0.1, 0.2, 0.3], 'http://example.com/entity1'),
([0.4, 0.5, 0.6], 'http://example.com/entity1'),
# Entity 2 vectors
([0.7, 0.8, 0.9], 'literal entity'),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entity_value(self, processor):
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='', is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify no insert was called for empty entity
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_none_entity_value(self, processor):
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify no insert was called for None entity
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_mixed_valid_invalid_entities(self, processor):
"""Test storing graph embeddings with mix of valid and invalid entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_entity = EntityEmbeddings(
entity=Value(value='http://example.com/valid', is_uri=True),
vectors=[[0.1, 0.2, 0.3]]
)
empty_entity = EntityEmbeddings(
entity=Value(value='', is_uri=False),
vectors=[[0.4, 0.5, 0.6]]
)
none_entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [valid_entity, empty_entity, none_entity]
await processor.store_graph_embeddings(message)
# Verify only valid entity was inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], 'http://example.com/valid'
)
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entities_list(self, processor):
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
await processor.store_graph_embeddings(message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
vectors=[]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify all vectors were inserted regardless of dimension
expected_calls = [
([0.1, 0.2], 'http://example.com/entity'),
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'),
([0.7, 0.8, 0.9], 'http://example.com/entity'),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
@pytest.mark.asyncio
async def test_store_graph_embeddings_uri_and_literal_entities(self, processor):
"""Test storing graph embeddings for both URI and literal entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
uri_entity = EntityEmbeddings(
entity=Value(value='http://example.com/uri_entity', is_uri=True),
vectors=[[0.1, 0.2, 0.3]]
)
literal_entity = EntityEmbeddings(
entity=Value(value='literal entity text', is_uri=False),
vectors=[[0.4, 0.5, 0.6]]
)
message.entities = [uri_entity, literal_entity]
await processor.store_graph_embeddings(message)
# Verify both entities were inserted
expected_calls = [
([0.1, 0.2, 0.3], 'http://example.com/uri_entity'),
([0.4, 0.5, 0.6], 'literal entity text'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'store_uri')
assert args.store_uri == 'http://localhost:19530'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
parser = ArgumentParser()
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--store-uri', 'http://custom-milvus:19530'
])
assert args.store_uri == 'http://custom-milvus:19530'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
parser = ArgumentParser()
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
assert args.store_uri == 'http://short-milvus:19530'
@patch('trustgraph.storage.graph_embeddings.milvus.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.graph_embeddings.milvus.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
)
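The Milvus store loop these tests describe is deliberately simple: every vector of every entity is inserted, and entities whose value is empty or `None` are skipped outright. A minimal sketch under those assumptions (plain dicts stand in for the `EntityEmbeddings`/`Value` schema objects, and `vecstore` is any object exposing the `insert(vector, entity)` method asserted on above):

```python
# Sketch of the insert loop the Milvus tests above assert on. Plain dicts
# stand in for EntityEmbeddings/Value; 'vecstore' is anything with
# insert(vector, entity_value), mirroring the mocked EntityVectors store.

def store_graph_embeddings_sketch(vecstore, entities):
    for entity in entities:
        value = entity["value"]
        if not value:  # skips both '' and None, per the tests above
            continue
        for vec in entity["vectors"]:
            vecstore.insert(vec, value)
```

The same loop explains the mixed valid/invalid case: only the valid entity's vectors ever reach `insert`, regardless of how many empty or `None` entities surround it.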

"""
Tests for Pinecone graph embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
import uuid
from trustgraph.storage.graph_embeddings.pinecone.write import Processor
from trustgraph.schema import EntityEmbeddings, Value
class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test cases for Pinecone graph embeddings storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entity embeddings
entity1 = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
entity2 = EntityEmbeddings(
entity=Value(value="entity2", is_uri=False),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [entity1, entity2]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone') as mock_pinecone_class:
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
processor = Processor(
taskgroup=MagicMock(),
id='test-pinecone-ge-storage',
api_key='test-api-key'
)
return processor
@patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
@patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'env-api-key')
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
"""Test processor initialization with default parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
assert processor.pinecone == mock_pinecone
assert processor.api_key == 'env-api-key'
assert processor.cloud == 'aws'
assert processor.region == 'us-east-1'
@patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
"""Test processor initialization with custom parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='custom-api-key',
cloud='gcp',
region='us-west1'
)
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
assert processor.api_key == 'custom-api-key'
assert processor.cloud == 'gcp'
assert processor.region == 'us-west1'
@patch('trustgraph.storage.graph_embeddings.pinecone.write.PineconeGRPC')
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
"""Test processor initialization with custom URL (GRPC mode)"""
mock_pinecone = MagicMock()
mock_pinecone_grpc_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='test-api-key',
url='https://custom-host.pinecone.io'
)
mock_pinecone_grpc_class.assert_called_once_with(
api_key='test-api-key',
host='https://custom-host.pinecone.io'
)
assert processor.pinecone == mock_pinecone
assert processor.url == 'https://custom-host.pinecone.io'
@patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'not-specified')
def test_processor_initialization_missing_api_key(self):
"""Test processor initialization fails with missing API key"""
taskgroup_mock = MagicMock()
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
Processor(taskgroup=taskgroup_mock)
@pytest.mark.asyncio
async def test_store_graph_embeddings_single_entity(self, processor):
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.entities = [entity]
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_graph_embeddings(message)
# Verify index name and operations
expected_index_name = "t-test_user-test_collection-3"
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
assert mock_index.upsert.call_count == 2
# Check first vector upsert
first_call = mock_index.upsert.call_args_list[0]
first_vectors = first_call[1]['vectors']
assert len(first_vectors) == 1
assert first_vectors[0]['id'] == 'id1'
assert first_vectors[0]['values'] == [0.1, 0.2, 0.3]
assert first_vectors[0]['metadata']['entity'] == "http://example.org/entity1"
# Check second vector upsert
second_call = mock_index.upsert.call_args_list[1]
second_vectors = second_call[1]['vectors']
assert len(second_vectors) == 1
assert second_vectors[0]['id'] == 'id2'
assert second_vectors[0]['values'] == [0.4, 0.5, 0.6]
assert second_vectors[0]['metadata']['entity'] == "http://example.org/entity1"
@pytest.mark.asyncio
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
"""Test storing graph embeddings for multiple entities"""
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(mock_message)
# Verify upsert was called for each vector (3 total)
assert mock_index.upsert.call_count == 3
# Verify entity values in metadata
calls = mock_index.upsert.call_args_list
assert calls[0][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1"
assert calls[1][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1"
assert calls[2][1]['vectors'][0]['metadata']['entity'] == "entity2"
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_creation(self, processor):
"""Test automatic index creation when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist initially
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock index creation
processor.pinecone.describe_index.return_value.status = {"ready": True}
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
# Verify index creation was called
expected_index_name = "t-test_user-test_collection-3"
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
assert create_call[1]['dimension'] == 3
assert create_call[1]['metric'] == "cosine"
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entity_value(self, processor):
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
# Verify no upsert was called for empty entity
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_none_entity_value(self, processor):
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
# Verify no upsert was called for None entity
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
)
message.entities = [entity]
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
mock_index_3d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
elif name.endswith("-3"):
return mock_index_3d
processor.pinecone.Index.side_effect = mock_index_side_effect
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(message)
# Verify different indexes were used for different dimensions
assert processor.pinecone.Index.call_count == 3
mock_index_2d.upsert.assert_called_once()
mock_index_4d.upsert.assert_called_once()
mock_index_3d.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entities_list(self, processor):
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[]
)
message.entities = [entity]
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_creation_failure(self, processor):
"""Test handling of index creation failure"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist and creation fails
processor.pinecone.has_index.return_value = False
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
with pytest.raises(Exception, match="Index creation failed"):
await processor.store_graph_embeddings(message)
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_creation_timeout(self, processor):
"""Test handling of index creation timeout"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist and never becomes ready
processor.pinecone.has_index.return_value = False
processor.pinecone.describe_index.return_value.status = {"ready": False}
with patch('time.sleep'): # Speed up the test
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
await processor.store_graph_embeddings(message)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added by parsing empty args
args = parser.parse_args([])
assert hasattr(args, 'api_key')
assert args.api_key == 'not-specified' # Default value when no env var
assert hasattr(args, 'url')
assert args.url is None
assert hasattr(args, 'cloud')
assert args.cloud == 'aws'
assert hasattr(args, 'region')
assert args.region == 'us-east-1'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--api-key', 'custom-api-key',
'--url', 'https://custom-host.pinecone.io',
'--cloud', 'gcp',
'--region', 'us-west1'
])
assert args.api_key == 'custom-api-key'
assert args.url == 'https://custom-host.pinecone.io'
assert args.cloud == 'gcp'
assert args.region == 'us-west1'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args([
'-a', 'short-api-key',
'-u', 'https://short-host.pinecone.io'
])
assert args.api_key == 'short-api-key'
assert args.url == 'https://short-host.pinecone.io'
@patch('trustgraph.storage.graph_embeddings.pinecone.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.graph_embeddings.pinecone.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nAccepts entity/vector pairs and writes them to a Pinecone store.\n"
)

"""
Unit tests for trustgraph.storage.graph_embeddings.qdrant.write
Starting small with a single test to verify basic functionality
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
"""Test Qdrant graph embeddings storage functionality"""
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
"""Test basic Qdrant processor initialization"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify base class initialization was called
mock_base_init.assert_called_once()
# Verify QdrantClient was created with correct parameters
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
# Verify processor attributes
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
assert hasattr(processor, 'last_collection')
assert processor.last_collection is None
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_creates_new_collection(self, mock_base_init, mock_qdrant_client):
"""Test get_collection creates a new collection when it doesn't exist"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Act
collection_name = processor.get_collection(dim=512, user='test_user', collection='test_collection')
# Assert
expected_name = 't_test_user_test_collection_512'
assert collection_name == expected_name
assert processor.last_collection == expected_name
# Verify collection existence check and creation
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
mock_qdrant_instance.create_collection.assert_called_once()
# Verify create_collection was called with correct parameters
create_call_args = mock_qdrant_instance.create_collection.call_args
assert create_call_args[1]['collection_name'] == expected_name
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_store_graph_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test storing graph embeddings with basic message"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = 'test-uuid-123'
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with entities and vectors
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_entity = MagicMock()
mock_entity.entity.value = 'test_entity'
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_message.entities = [mock_entity]
# Act
await processor.store_graph_embeddings(mock_message)
# Assert
# Verify collection existence was checked
expected_collection = 't_test_user_test_collection_3'
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == expected_collection
assert len(upsert_call_args[1]['points']) == 1
point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3]
assert point.payload['entity'] == 'test_entity'
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_uses_existing_collection(self, mock_base_init, mock_qdrant_client):
"""Test get_collection uses existing collection without creating new one"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Act
collection_name = processor.get_collection(dim=256, user='existing_user', collection='existing_collection')
# Assert
expected_name = 't_existing_user_existing_collection_256'
assert collection_name == expected_name
assert processor.last_collection == expected_name
# Verify collection existence check was performed
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
# Verify create_collection was NOT called
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_caches_last_collection(self, mock_base_init, mock_qdrant_client):
"""Test get_collection skips checks when using same collection"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# First call
collection_name1 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
# Act - Second call with same parameters
collection_name2 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
# Assert
expected_name = 't_cache_user_cache_collection_128'
assert collection_name1 == expected_name
assert collection_name2 == expected_name
# Verify second call skipped existence check (cached)
mock_qdrant_instance.collection_exists.assert_not_called()
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
"""Test get_collection handles collection creation exceptions"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
processor.get_collection(dim=512, user='error_user', collection='error_collection')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_store_graph_embeddings_multiple_entities(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test storing graph embeddings with multiple entities"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = 'test-uuid'
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with multiple entities
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_entity1 = MagicMock()
mock_entity1.entity.value = 'entity_one'
mock_entity1.vectors = [[0.1, 0.2]]
mock_entity2 = MagicMock()
mock_entity2.entity.value = 'entity_two'
mock_entity2.vectors = [[0.3, 0.4]]
mock_message.entities = [mock_entity1, mock_entity2]
# Act
await processor.store_graph_embeddings(mock_message)
# Assert
# Should be called twice (once per entity)
assert mock_qdrant_instance.upsert.call_count == 2
# Verify both entities were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
# First entity
first_call = upsert_calls[0]
first_point = first_call[1]['points'][0]
assert first_point.vector == [0.1, 0.2]
assert first_point.payload['entity'] == 'entity_one'
# Second entity
second_call = upsert_calls[1]
second_point = second_call[1]['points'][0]
assert second_point.vector == [0.3, 0.4]
assert second_point.payload['entity'] == 'entity_two'
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test storing graph embeddings with multiple vectors per entity"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = 'test-uuid'
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with entity having multiple vectors
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_entity = MagicMock()
mock_entity.entity.value = 'multi_vector_entity'
mock_entity.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
mock_message.entities = [mock_entity]
# Act
await processor.store_graph_embeddings(mock_message)
# Assert
# Should be called 3 times (once per vector)
assert mock_qdrant_instance.upsert.call_count == 3
# Verify all vectors were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
expected_vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
for i, call in enumerate(upsert_calls):
point = call[1]['points'][0]
assert point.vector == expected_vectors[i]
assert point.payload['entity'] == 'multi_vector_entity'
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_store_graph_embeddings_empty_entity_value(self, mock_base_init, mock_qdrant_client):
"""Test storing graph embeddings skips empty entity values"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Create mock message with empty entity value
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_entity_empty = MagicMock()
mock_entity_empty.entity.value = "" # Empty string
mock_entity_empty.vectors = [[0.1, 0.2]]
mock_entity_none = MagicMock()
mock_entity_none.entity.value = None # None value
mock_entity_none.vectors = [[0.3, 0.4]]
mock_message.entities = [mock_entity_empty, mock_entity_none]
# Act
await processor.store_graph_embeddings(mock_message)
# Assert
# Should not call upsert for empty entities
mock_qdrant_instance.upsert.assert_not_called()
mock_qdrant_instance.collection_exists.assert_not_called()
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
"""Test processor initialization with default values"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
# No store_uri or api_key provided - should use defaults
}
# Act
# Act
Processor(**config)
# Assert
# Verify QdrantClient was created with the default URI and no API key
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
"""Test that add_args() calls parent add_args method"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_client.return_value = MagicMock()
mock_parser = MagicMock()
# Act
with patch('trustgraph.base.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(mock_parser)
# Assert
mock_parent_add_args.assert_called_once_with(mock_parser)
# Verify processor-specific arguments were added
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
if __name__ == '__main__':
pytest.main([__file__])

"""
Unit tests for Cassandra Object Storage Processor
Tests the business logic of the object storage processor including:
- Schema configuration handling
- Type conversions
- Name sanitization
- Table structure generation
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class TestObjectsCassandraStorageLogic:
"""Test business logic without FlowProcessor dependencies"""
def test_sanitize_name(self):
"""Test name sanitization for Cassandra compatibility"""
processor = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
# Test various name patterns
assert processor.sanitize_name("simple_name") == "simple_name"
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
def test_get_cassandra_type(self):
"""Test field type conversion to Cassandra types"""
processor = MagicMock()
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
# Basic type mappings
assert processor.get_cassandra_type("string") == "text"
assert processor.get_cassandra_type("boolean") == "boolean"
assert processor.get_cassandra_type("timestamp") == "timestamp"
assert processor.get_cassandra_type("uuid") == "uuid"
# Integer types with size hints
assert processor.get_cassandra_type("integer", size=2) == "int"
assert processor.get_cassandra_type("integer", size=8) == "bigint"
# Float types with size hints
assert processor.get_cassandra_type("float", size=2) == "float"
assert processor.get_cassandra_type("float", size=8) == "double"
# Unknown type defaults to text
assert processor.get_cassandra_type("unknown_type") == "text"
def test_convert_value(self):
"""Test value conversion for different field types"""
processor = MagicMock()
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
# Integer conversions
assert processor.convert_value("123", "integer") == 123
assert processor.convert_value(123.5, "integer") == 123
assert processor.convert_value(None, "integer") is None
# Float conversions
assert processor.convert_value("123.45", "float") == 123.45
assert processor.convert_value(123, "float") == 123.0
# Boolean conversions
assert processor.convert_value("true", "boolean") is True
assert processor.convert_value("false", "boolean") is False
assert processor.convert_value("1", "boolean") is True
assert processor.convert_value("0", "boolean") is False
assert processor.convert_value("yes", "boolean") is True
assert processor.convert_value("no", "boolean") is False
# String conversions
assert processor.convert_value(123, "string") == "123"
assert processor.convert_value(True, "string") == "True"
def test_table_creation_cql_generation(self):
"""Test CQL generation for table creation"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create test schema
schema = RowSchema(
name="customer_records",
description="Test customer schema",
fields=[
Field(
name="customer_id",
type="string",
size=50,
primary=True,
required=True,
indexed=False
),
Field(
name="email",
type="string",
size=100,
required=True,
indexed=True
),
Field(
name="age",
type="integer",
size=4,
required=False,
indexed=False
)
]
)
# Call ensure_table
processor.ensure_table("test_user", "customer_records", schema)
# Verify keyspace was ensured (check that it was added to known_keyspaces)
assert "test_user" in processor.known_keyspaces
# Check the CQL that was executed (first call should be table creation)
all_calls = processor.session.execute.call_args_list
table_creation_cql = all_calls[0][0][0] # First call
# Verify table structure (keyspace uses sanitize_name, table uses sanitize_table)
assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql
assert "collection text" in table_creation_cql
assert "customer_id text" in table_creation_cql
assert "email text" in table_creation_cql
assert "age int" in table_creation_cql
assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql
def test_table_creation_without_primary_key(self):
"""Test table creation when no primary key is defined"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema without primary key
schema = RowSchema(
name="events",
description="Event log",
fields=[
Field(name="event_type", type="string", size=50),
Field(name="timestamp", type="timestamp", size=0)
]
)
# Call ensure_table
processor.ensure_table("test_user", "events", schema)
# Check the CQL includes synthetic_id (field names don't get o_ prefix)
executed_cql = processor.session.execute.call_args[0][0]
assert "synthetic_id uuid" in executed_cql
assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configurations"""
processor = MagicMock()
processor.schemas = {}
processor.config_key = "schema"
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test configuration
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer data",
"fields": [
{
"name": "id",
"type": "string",
"primary_key": True,
"required": True
},
{
"name": "name",
"type": "string",
"required": True
},
{
"name": "balance",
"type": "float",
"size": 8
}
]
})
}
}
# Process configuration
await processor.on_schema_config(config, version=1)
# Verify schema was loaded
assert "customer_records" in processor.schemas
schema = processor.schemas["customer_records"]
assert schema.name == "customer_records"
assert len(schema.fields) == 3
# Check field properties
id_field = schema.fields[0]
assert id_field.name == "id"
assert id_field.type == "string"
assert id_field.primary is True
# Note: Field.required always returns False due to Pulsar schema limitations
# The actual required value is tracked during schema parsing
@pytest.mark.asyncio
async def test_object_processing_logic(self):
"""Test the logic for processing ExtractedObject"""
processor = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="integer", size=4)
]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create test object
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
metadata=[]
),
schema_name="test_schema",
values={"id": "123", "value": "456"},
confidence=0.9,
source_span="test source"
)
# Create mock message
msg = MagicMock()
msg.value.return_value = test_obj
# Process object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once_with("test_user", "test_schema", processor.schemas["test_schema"])
# Verify insert was executed (keyspace normal, table with o_ prefix)
processor.session.execute.assert_called_once()
insert_cql = processor.session.execute.call_args[0][0]
values = processor.session.execute.call_args[0][1]
assert "INSERT INTO test_user.o_test_schema" in insert_cql
assert "collection" in insert_cql
assert values[0] == "test_collection" # collection value
assert values[1] == "123" # id value
assert values[2] == 456 # converted integer value
def test_secondary_index_creation(self):
"""Test that secondary indexes are created for indexed fields"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema with indexed field
schema = RowSchema(
name="products",
description="Product catalog",
fields=[
Field(name="product_id", type="string", size=50, primary=True),
Field(name="category", type="string", size=30, indexed=True),
Field(name="price", type="float", size=8, indexed=True)
]
)
# Call ensure_table
processor.ensure_table("test_user", "products", schema)
# Should have 3 calls: create table + 2 indexes
assert processor.session.execute.call_count == 3
# Check index creation calls (table has o_ prefix, fields don't)
calls = processor.session.execute.call_args_list
index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]]
assert len(index_calls) == 2
assert any("o_products_category_idx" in call for call in index_calls)
assert any("o_products_price_idx" in call for call in index_calls)

"""
Tests for Cassandra triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.storage.triples.cassandra.write import Processor
from trustgraph.schema import Value, Triple
class TestCassandraStorageProcessor:
"""Test cases for Cassandra storage processor"""
def test_processor_initialization_with_defaults(self):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
assert processor.graph_host == ['localhost']
assert processor.username is None
assert processor.password is None
assert processor.table is None
def test_processor_initialization_with_custom_params(self):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
id='custom-storage',
graph_host='cassandra.example.com',
graph_username='testuser',
graph_password='testpass'
)
assert processor.graph_host == ['cassandra.example.com']
assert processor.username == 'testuser'
assert processor.password == 'testpass'
assert processor.table is None
def test_processor_initialization_with_partial_auth(self):
"""Test processor initialization with only username (no password)"""
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
graph_username='testuser'
)
assert processor.username == 'testuser'
assert processor.password is None
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
async def test_table_switching_with_auth(self, mock_trustgraph):
"""Test table switching logic when authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
processor = Processor(
taskgroup=taskgroup_mock,
graph_username='testuser',
graph_password='testpass'
)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify TrustGraph was called with auth parameters
mock_trustgraph.assert_called_once_with(
hosts=['localhost'],
keyspace='user1',
table='collection1',
username='testuser',
password='testpass'
)
assert processor.table == ('user1', 'collection1')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
async def test_table_switching_without_auth(self, mock_trustgraph):
"""Test table switching logic when no authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user2'
mock_message.metadata.collection = 'collection2'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify TrustGraph was called without auth parameters
mock_trustgraph.assert_called_once_with(
hosts=['localhost'],
keyspace='user2',
table='collection2'
)
assert processor.table == ('user2', 'collection2')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
async def test_table_reuse_when_same(self, mock_trustgraph):
"""Test that TrustGraph is not recreated when table hasn't changed"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
# First call should create TrustGraph
await processor.store_triples(mock_message)
assert mock_trustgraph.call_count == 1
# Second call with same table should reuse TrustGraph
await processor.store_triples(mock_message)
assert mock_trustgraph.call_count == 1 # Should not increase
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
async def test_triple_insertion(self, mock_trustgraph):
"""Test that triples are properly inserted into Cassandra"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock triples
triple1 = MagicMock()
triple1.s.value = 'subject1'
triple1.p.value = 'predicate1'
triple1.o.value = 'object1'
triple2 = MagicMock()
triple2.s.value = 'subject2'
triple2.p.value = 'predicate2'
triple2.o.value = 'object2'
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple1, triple2]
await processor.store_triples(mock_message)
# Verify both triples were inserted
assert mock_tg_instance.insert.call_count == 2
mock_tg_instance.insert.assert_any_call('subject1', 'predicate1', 'object1')
mock_tg_instance.insert.assert_any_call('subject2', 'predicate2', 'object2')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
async def test_triple_insertion_with_empty_list(self, mock_trustgraph):
"""Test behavior when message has no triples"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message with empty triples
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify no triples were inserted
mock_tg_instance.insert.assert_not_called()
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph):
"""Test exception handling during TrustGraph creation"""
taskgroup_mock = MagicMock()
mock_trustgraph.side_effect = Exception("Connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
# Verify sleep was called before re-raising
mock_sleep.assert_called_once_with(1)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once_with(parser)
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'localhost'
assert hasattr(args, 'graph_username')
assert args.graph_username is None
assert hasattr(args, 'graph_password')
assert args.graph_password is None
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-host', 'cassandra.example.com',
'--graph-username', 'testuser',
'--graph-password', 'testpass'
])
assert args.graph_host == 'cassandra.example.com'
assert args.graph_username == 'testuser'
assert args.graph_password == 'testpass'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'short.example.com'])
assert args.graph_host == 'short.example.com'
@patch('trustgraph.storage.triples.cassandra.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.triples.cassandra.write import run, default_ident
run()
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph):
"""Test table switching when different tables are used in sequence"""
taskgroup_mock = MagicMock()
mock_tg_instance1 = MagicMock()
mock_tg_instance2 = MagicMock()
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
processor = Processor(taskgroup=taskgroup_mock)
# First message with table1
mock_message1 = MagicMock()
mock_message1.metadata.user = 'user1'
mock_message1.metadata.collection = 'collection1'
mock_message1.triples = []
await processor.store_triples(mock_message1)
assert processor.table == ('user1', 'collection1')
assert processor.tg == mock_tg_instance1
# Second message with different table
mock_message2 = MagicMock()
mock_message2.metadata.user = 'user2'
mock_message2.metadata.collection = 'collection2'
mock_message2.triples = []
await processor.store_triples(mock_message2)
assert processor.table == ('user2', 'collection2')
assert processor.tg == mock_tg_instance2
# Verify TrustGraph was created twice for different tables
assert mock_trustgraph.call_count == 2
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph):
"""Test storing triples with special characters and unicode"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create triple with special characters
triple = MagicMock()
triple.s.value = 'subject with spaces & symbols'
triple.p.value = 'predicate:with/colons'
triple.o.value = 'object with "quotes" and unicode: ñáéíóú'
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with(
'subject with spaces & symbols',
'predicate:with/colons',
'object with "quotes" and unicode: ñáéíóú'
)
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph):
"""Test that table remains unchanged when TrustGraph creation fails"""
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
# Set an initial table
processor.table = ('old_user', 'old_collection')
# Mock TrustGraph to raise exception
mock_trustgraph.side_effect = Exception("Connection failed")
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
# Table should remain unchanged: self.table is only updated after
# TrustGraph creation succeeds
assert processor.table == ('old_user', 'old_collection')
# The tg handle, however, is reset to None
assert processor.tg is None
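Taken together, the Cassandra tests above pin down the table-switching contract: a (user, collection) key selects the keyspace and table, the TrustGraph client is rebuilt only when that key changes, and a failed rebuild clears the handle, sleeps briefly, and re-raises without updating the key. A minimal sketch of that contract, with names reconstructed from the assertions (the class and the injected factory are illustrative, not the real Processor in trustgraph.storage.triples.cassandra.write):

```python
import time


class CassandraWriterSketch:
    def __init__(self, hosts=None, username=None, password=None,
                 trustgraph_factory=None):
        self.hosts = hosts or ["localhost"]
        self.username = username
        self.password = password
        self.table = None          # (user, collection) currently bound
        self.tg = None             # active TrustGraph handle
        # Injected so the sketch is testable without a Cassandra driver.
        self._factory = trustgraph_factory

    async def store_triples(self, message):
        table = (message.metadata.user, message.metadata.collection)
        if table != self.table:
            try:
                kwargs = dict(hosts=self.hosts,
                              keyspace=table[0], table=table[1])
                if self.username and self.password:
                    kwargs["username"] = self.username
                    kwargs["password"] = self.password
                self.tg = self._factory(**kwargs)
            except Exception:
                self.tg = None     # handle is discarded
                time.sleep(1)      # back off before re-raising
                raise              # self.table is left unchanged
            self.table = table
        for t in message.triples:
            self.tg.insert(t.s.value, t.p.value, t.o.value)
```

Reusing the handle across messages with the same key is what test_table_reuse_when_same asserts; leaving self.table stale on failure is what test_store_triples_preserves_old_table_on_exception asserts.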

@@ -0,0 +1,436 @@
"""
Tests for FalkorDB triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.falkordb.write import Processor
from trustgraph.schema import Value, Triple
class TestFalkorDBStorageProcessor:
"""Test cases for FalkorDB storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a test triple
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
message.triples = [triple]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.triples.falkordb.write.FalkorDB') as mock_falkordb:
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
return Processor(
taskgroup=MagicMock(),
id='test-falkordb-storage',
graph_url='falkor://localhost:6379',
database='test_db'
)
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
def test_processor_initialization_with_defaults(self, mock_falkordb):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'falkordb'
mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379')
mock_client.select_graph.assert_called_once_with('falkordb')
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
def test_processor_initialization_with_custom_params(self, mock_falkordb):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
processor = Processor(
taskgroup=taskgroup_mock,
graph_url='falkor://custom:6379',
database='custom_db'
)
assert processor.db == 'custom_db'
mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379')
mock_client.select_graph.assert_called_once_with('custom_db')
def test_create_node(self, processor):
"""Test node creation"""
test_uri = 'http://example.com/node'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_node(test_uri)
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri})",
params={
"uri": test_uri,
},
)
def test_create_literal(self, processor):
"""Test literal creation"""
test_value = 'test literal value'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_literal(test_value)
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value})",
params={
"value": test_value,
},
)
def test_relate_node(self, processor):
"""Test node-to-node relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
dest_uri = 'http://example.com/dest'
mock_result = MagicMock()
mock_result.nodes_created = 0
mock_result.run_time_ms = 5
processor.io.query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri)
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
params={
"src": src_uri,
"dest": dest_uri,
"uri": pred_uri,
},
)
def test_relate_literal(self, processor):
"""Test node-to-literal relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
literal_value = 'literal destination'
mock_result = MagicMock()
mock_result.nodes_created = 0
mock_result.run_time_ms = 5
processor.io.query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value)
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
params={
"src": src_uri,
"dest": literal_value,
"uri": pred_uri,
},
)
@pytest.mark.asyncio
async def test_store_triples_with_uri_object(self, processor):
"""Test storing triple with URI object"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='http://example.com/object', is_uri=True)
)
message.triples = [triple]
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
# Create object node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/object"}}),
# Create relationship
(("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate"}}),
]
assert processor.io.query.call_count == 3
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.io.query.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
@pytest.mark.asyncio
async def test_store_triples_with_literal_object(self, processor, mock_message):
"""Test storing triple with literal object"""
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(mock_message)
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
# Create literal object
(("MERGE (n:Literal {value: $value})",), {"params": {"value": "literal object"}}),
# Create relationship
(("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate"}}),
]
assert processor.io.query.call_count == 3
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.io.query.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
@pytest.mark.asyncio
async def test_store_triples_multiple_triples(self, processor):
"""Test storing multiple triples"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object1', is_uri=False)
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
)
message.triples = [triple1, triple2]
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
# Verify first triple operations
first_triple_calls = processor.io.query.call_args_list[0:3]
assert first_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject1"
assert first_triple_calls[1][1]["params"]["value"] == "literal object1"
assert first_triple_calls[2][1]["params"]["src"] == "http://example.com/subject1"
# Verify second triple operations
second_triple_calls = processor.io.query.call_args_list[3:6]
assert second_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject2"
assert second_triple_calls[1][1]["params"]["uri"] == "http://example.com/object2"
assert second_triple_calls[2][1]["params"]["src"] == "http://example.com/subject2"
@pytest.mark.asyncio
async def test_store_triples_empty_list(self, processor):
"""Test storing empty triples list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
await processor.store_triples(message)
# Verify no queries were made
processor.io.query.assert_not_called()
@pytest.mark.asyncio
async def test_store_triples_mixed_objects(self, processor):
"""Test storing triples with mixed URI and literal objects"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
)
message.triples = [triple1, triple2]
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
# Verify first triple creates literal
assert "Literal" in processor.io.query.call_args_list[1][0][0]
assert processor.io.query.call_args_list[1][1]["params"]["value"] == "literal object"
# Verify second triple creates node
assert "Node" in processor.io.query.call_args_list[4][0][0]
assert processor.io.query.call_args_list[4][1]["params"]["uri"] == "http://example.com/object2"
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_url')
assert args.graph_url == 'falkor://falkordb:6379'
assert hasattr(args, 'database')
assert args.database == 'falkordb'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-url', 'falkor://custom:6379',
'--database', 'custom_db'
])
assert args.graph_url == 'falkor://custom:6379'
assert args.database == 'custom_db'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'falkor://short:6379'])
assert args.graph_url == 'falkor://short:6379'
@patch('trustgraph.storage.triples.falkordb.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.triples.falkordb.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nGraph writer. Input is graph edge. Writes edges to FalkorDB graph.\n"
)
def test_create_node_with_special_characters(self, processor):
"""Test node creation with special characters in URI"""
test_uri = 'http://example.com/node with spaces & symbols'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_node(test_uri)
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri})",
params={
"uri": test_uri,
},
)
def test_create_literal_with_special_characters(self, processor):
"""Test literal creation with special characters"""
test_value = 'literal with "quotes" and \n newlines'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_literal(test_value)
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value})",
params={
"value": test_value,
},
)
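The FalkorDB tests above fix the write pattern: three parameterised queries per triple, with the object MERGEd as a :Node or a :Literal depending on is_uri. A minimal sketch under those assumptions (the Cypher text is copied from the assertions; the wrapper class and the io.query fake are illustrative, not the real Processor):

```python
class FalkorWriterSketch:
    def __init__(self, io):
        self.io = io  # object exposing query(cypher, params=...)

    def create_node(self, uri):
        self.io.query("MERGE (n:Node {uri: $uri})", params={"uri": uri})

    def create_literal(self, value):
        self.io.query("MERGE (n:Literal {value: $value})",
                      params={"value": value})

    def relate_node(self, src, uri, dest):
        self.io.query(
            "MATCH (src:Node {uri: $src}) "
            "MATCH (dest:Node {uri: $dest}) "
            "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
            params={"src": src, "dest": dest, "uri": uri})

    def relate_literal(self, src, uri, dest):
        self.io.query(
            "MATCH (src:Node {uri: $src}) "
            "MATCH (dest:Literal {value: $dest}) "
            "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
            params={"src": src, "dest": dest, "uri": uri})

    async def store_triples(self, message):
        # Three queries per triple: subject node, object node/literal, edge.
        for t in message.triples:
            self.create_node(t.s.value)
            if t.o.is_uri:
                self.create_node(t.o.value)
                self.relate_node(t.s.value, t.p.value, t.o.value)
            else:
                self.create_literal(t.o.value)
                self.relate_literal(t.s.value, t.p.value, t.o.value)
```

Because every statement is a MERGE keyed on uri/value, repeated writes of the same triple are idempotent, which is why the tests can assert exact call counts.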

@@ -0,0 +1,441 @@
"""
Tests for Memgraph triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.memgraph.write import Processor
from trustgraph.schema import Value, Triple
class TestMemgraphStorageProcessor:
"""Test cases for Memgraph storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a test triple
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
message.triples = [triple]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') as mock_graph_db:
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
return Processor(
taskgroup=MagicMock(),
id='test-memgraph-storage',
graph_host='bolt://localhost:7687',
username='test_user',
password='test_pass',
database='test_db'
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'memgraph'
mock_graph_db.driver.assert_called_once_with(
'bolt://memgraph:7687',
auth=('memgraph', 'password')
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='bolt://custom:7687',
username='custom_user',
password='custom_pass',
database='custom_db'
)
assert processor.db == 'custom_db'
mock_graph_db.driver.assert_called_once_with(
'bolt://custom:7687',
auth=('custom_user', 'custom_pass')
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_indexes_success(self, mock_graph_db):
"""Test successful index creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
# Verify index creation calls
expected_calls = [
"CREATE INDEX ON :Node",
"CREATE INDEX ON :Node(uri)",
"CREATE INDEX ON :Literal",
"CREATE INDEX ON :Literal(value)"
]
assert mock_session.run.call_count == len(expected_calls)
for i, expected_call in enumerate(expected_calls):
actual_call = mock_session.run.call_args_list[i][0][0]
assert actual_call == expected_call
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_indexes_with_exceptions(self, mock_graph_db):
"""Test index creation with exceptions (should be ignored)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_session = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_driver.session.return_value.__enter__.return_value = mock_session
# Make all index creation calls raise exceptions
mock_session.run.side_effect = Exception("Index already exists")
# Should not raise an exception
processor = Processor(taskgroup=taskgroup_mock)
# Verify all index creation calls were attempted
assert mock_session.run.call_count == 4
def test_create_node(self, processor):
"""Test node creation"""
test_uri = 'http://example.com/node'
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
processor.io.execute_query.return_value = mock_result
processor.create_node(test_uri)
processor.io.execute_query.assert_called_once_with(
"MERGE (n:Node {uri: $uri})",
uri=test_uri,
database_=processor.db
)
def test_create_literal(self, processor):
"""Test literal creation"""
test_value = 'test literal value'
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
processor.io.execute_query.return_value = mock_result
processor.create_literal(test_value)
processor.io.execute_query.assert_called_once_with(
"MERGE (n:Literal {value: $value})",
value=test_value,
database_=processor.db
)
def test_relate_node(self, processor):
"""Test node-to-node relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
dest_uri = 'http://example.com/dest'
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 5
mock_result.summary = mock_summary
processor.io.execute_query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri)
processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=src_uri, dest=dest_uri, uri=pred_uri,
database_=processor.db
)
def test_relate_literal(self, processor):
"""Test node-to-literal relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
literal_value = 'literal destination'
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 5
mock_result.summary = mock_summary
processor.io.execute_query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value)
processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src=src_uri, dest=literal_value, uri=pred_uri,
database_=processor.db
)
def test_create_triple_with_uri_object(self, processor):
"""Test triple creation with URI object"""
mock_tx = MagicMock()
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='http://example.com/object', is_uri=True)
)
processor.create_triple(mock_tx, triple)
# Verify transaction calls
expected_calls = [
# Create subject node
("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
# Create object node
("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/object'}),
# Create relationship
("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate'})
]
assert mock_tx.run.call_count == 3
for i, (expected_query, expected_params) in enumerate(expected_calls):
actual_call = mock_tx.run.call_args_list[i]
assert actual_call[0][0] == expected_query
assert actual_call[1] == expected_params
def test_create_triple_with_literal_object(self, processor):
"""Test triple creation with literal object"""
mock_tx = MagicMock()
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
processor.create_triple(mock_tx, triple)
# Verify transaction calls
expected_calls = [
# Create subject node
("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
# Create literal object
("MERGE (n:Literal {value: $value})", {'value': 'literal object'}),
# Create relationship
("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate'})
]
assert mock_tx.run.call_count == 3
for i, (expected_query, expected_params) in enumerate(expected_calls):
actual_call = mock_tx.run.call_args_list[i]
assert actual_call[0][0] == expected_query
assert actual_call[1] == expected_params
@pytest.mark.asyncio
async def test_store_triples_single_triple(self, processor, mock_message):
"""Test storing a single triple"""
mock_session = MagicMock()
processor.io.session.return_value.__enter__.return_value = mock_session
# Reset the mock to clear the initialization call
processor.io.session.reset_mock()
await processor.store_triples(mock_message)
# Verify session was created with correct database
processor.io.session.assert_called_once_with(database=processor.db)
# Verify execute_write was called once per triple
mock_session.execute_write.assert_called_once()
# Verify the triple was passed to create_triple
call_args = mock_session.execute_write.call_args
assert call_args[0][0] == processor.create_triple
assert call_args[0][1] == mock_message.triples[0]
@pytest.mark.asyncio
async def test_store_triples_multiple_triples(self, processor):
"""Test storing multiple triples"""
mock_session = MagicMock()
processor.io.session.return_value.__enter__.return_value = mock_session
# Reset the mock to clear the initialization call
processor.io.session.reset_mock()
# Create message with multiple triples
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object1', is_uri=False)
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
)
message.triples = [triple1, triple2]
await processor.store_triples(message)
# Verify session was called twice (once per triple)
assert processor.io.session.call_count == 2
# Verify execute_write was called once per triple
assert mock_session.execute_write.call_count == 2
# Verify each triple was processed
call_args_list = mock_session.execute_write.call_args_list
assert call_args_list[0][0][1] == triple1
assert call_args_list[1][0][1] == triple2
@pytest.mark.asyncio
async def test_store_triples_empty_list(self, processor):
"""Test storing empty triples list"""
mock_session = MagicMock()
processor.io.session.return_value.__enter__.return_value = mock_session
# Reset the mock to clear the initialization call
processor.io.session.reset_mock()
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
await processor.store_triples(message)
# Verify no session calls were made (no triples to process)
processor.io.session.assert_not_called()
# Verify no execute_write calls were made
mock_session.execute_write.assert_not_called()
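The `processor.io.session.return_value.__enter__.return_value` wiring used throughout these fixtures exists because the production code opens sessions with a `with` statement. A self-contained sketch of the pattern (the `io`/`session` names are illustrative):

```python
from unittest.mock import MagicMock

io = MagicMock()
session = MagicMock()
# `with io.session(...) as s:` first calls io.session(), then __enter__()
# on the result; wiring __enter__.return_value makes `s` be our mock
io.session.return_value.__enter__.return_value = session

with io.session(database="neo4j") as s:
    s.execute_write("work")

assert s is session
io.session.assert_called_once_with(database="neo4j")
session.execute_write.assert_called_once_with("work")
```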
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://memgraph:7687'
assert hasattr(args, 'username')
assert args.username == 'memgraph'
assert hasattr(args, 'password')
assert args.password == 'password'
assert hasattr(args, 'database')
assert args.database == 'memgraph'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-host', 'bolt://custom:7687',
'--username', 'custom_user',
'--password', 'custom_pass',
'--database', 'custom_db'
])
assert args.graph_host == 'bolt://custom:7687'
assert args.username == 'custom_user'
assert args.password == 'custom_pass'
assert args.database == 'custom_db'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'bolt://short:7687'])
assert args.graph_host == 'bolt://short:7687'
@patch('trustgraph.storage.triples.memgraph.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.triples.memgraph.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nGraph writer. Input is graph edge. Writes edges to Memgraph.\n"
)

"""
Tests for Neo4j triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.storage.triples.neo4j.write import Processor
class TestNeo4jStorageProcessor:
"""Test cases for Neo4j storage processor"""
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'neo4j'
mock_graph_db.driver.assert_called_once_with(
'bolt://neo4j:7687',
auth=('neo4j', 'password')
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='bolt://custom:7687',
username='testuser',
password='testpass',
database='testdb'
)
assert processor.db == 'testdb'
mock_graph_db.driver.assert_called_once_with(
'bolt://custom:7687',
auth=('testuser', 'testpass')
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_indexes_success(self, mock_graph_db):
"""Test successful index creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
# Verify index creation queries were executed
expected_calls = [
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)"
]
assert mock_session.run.call_count == 3
for expected_query in expected_calls:
mock_session.run.assert_any_call(expected_query)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_indexes_with_exceptions(self, mock_graph_db):
"""Test index creation with exceptions (should be ignored)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Make session.run raise exceptions
mock_session.run.side_effect = Exception("Index already exists")
# Should not raise exception - they should be caught and ignored
processor = Processor(taskgroup=taskgroup_mock)
# Should have tried to create all 3 indexes despite exceptions
assert mock_session.run.call_count == 3
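The test above combines `side_effect` with the processor's ignore-and-continue behaviour: every `run` call raises, yet `call_count` still advances. The mechanics in isolation, with an illustrative loop standing in for the index-creation code:

```python
from unittest.mock import MagicMock

session = MagicMock()
# side_effect set to an exception instance makes every call raise it
session.run.side_effect = Exception("Index already exists")

succeeded = 0
for query in ["CREATE INDEX a", "CREATE INDEX b", "CREATE INDEX c"]:
    try:
        session.run(query)
        succeeded += 1
    except Exception:
        pass  # mirror the processor: "already exists" failures are ignored

# every call raised, but the mock still recorded all three attempts
assert succeeded == 0
assert session.run.call_count == 3
```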
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_node(self, mock_graph_db):
"""Test node creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Test create_node
processor.create_node("http://example.com/node")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri})",
uri="http://example.com/node",
database_="neo4j"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_literal(self, mock_graph_db):
"""Test literal creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Test create_literal
processor.create_literal("literal value")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value})",
value="literal value",
database_="neo4j"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_relate_node(self, mock_graph_db):
"""Test node-to-node relationship creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Test relate_node
processor.relate_node(
"http://example.com/subject",
"http://example.com/predicate",
"http://example.com/object"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src="http://example.com/subject",
dest="http://example.com/object",
uri="http://example.com/predicate",
database_="neo4j"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_relate_literal(self, mock_graph_db):
"""Test node-to-literal relationship creation"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Test relate_literal
processor.relate_literal(
"http://example.com/subject",
"http://example.com/predicate",
"literal value"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src="http://example.com/subject",
dest="literal value",
uri="http://example.com/predicate",
database_="neo4j"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_handle_triples_with_uri_object(self, mock_graph_db):
"""Test handling triples message with URI object"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Create mock triple with URI object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "http://example.com/object"
triple.o.is_uri = True
# Create mock message
mock_message = MagicMock()
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify create_node was called for subject and object
# Verify relate_node was called
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri})",
{"uri": "http://example.com/subject", "database_": "neo4j"}
),
# Object node creation
(
"MERGE (n:Node {uri: $uri})",
{"uri": "http://example.com/object", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "http://example.com/object",
"uri": "http://example.com/predicate",
"database_": "neo4j"
}
)
]
assert mock_driver.execute_query.call_count == 3
for expected_query, expected_params in expected_calls:
mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_literal_object(self, mock_graph_db):
"""Test handling triples message with literal object"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Create mock triple with literal object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "literal value"
triple.o.is_uri = False
# Create mock message
mock_message = MagicMock()
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify create_node was called for subject
# Verify create_literal was called for object
# Verify relate_literal was called
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri})",
{"uri": "http://example.com/subject", "database_": "neo4j"}
),
# Literal creation
(
"MERGE (n:Literal {value: $value})",
{"value": "literal value", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "literal value",
"uri": "http://example.com/predicate",
"database_": "neo4j"
}
)
]
assert mock_driver.execute_query.call_count == 3
for expected_query, expected_params in expected_calls:
mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
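The loop above leans on `assert_any_call`, which matches if *any* recorded call used those exact args and kwargs, unlike `assert_called_with`, which only checks the most recent call. A minimal illustration with stand-in queries:

```python
from unittest.mock import MagicMock

driver = MagicMock()
driver.execute_query("MERGE (n:Node {uri: $uri})", uri="http://x", database_="neo4j")
driver.execute_query("MERGE (n:Literal {value: $value})", value="v", database_="neo4j")

# Matches the first call even though a later call followed it;
# kwargs must match exactly (names and values)
driver.execute_query.assert_any_call(
    "MERGE (n:Node {uri: $uri})", uri="http://x", database_="neo4j"
)
assert driver.execute_query.call_count == 2
```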
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_multiple_triples(self, mock_graph_db):
"""Test handling message with multiple triples"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Create mock triples
triple1 = MagicMock()
triple1.s.value = "http://example.com/subject1"
triple1.p.value = "http://example.com/predicate1"
triple1.o.value = "http://example.com/object1"
triple1.o.is_uri = True
triple2 = MagicMock()
triple2.s.value = "http://example.com/subject2"
triple2.p.value = "http://example.com/predicate2"
triple2.o.value = "literal value"
triple2.o.is_uri = False
# Create mock message
mock_message = MagicMock()
mock_message.triples = [triple1, triple2]
await processor.store_triples(mock_message)
# Should have processed both triples
# Triple1: 2 nodes + 1 relationship = 3 calls
# Triple2: 1 node + 1 literal + 1 relationship = 3 calls
# Total: 6 calls
assert mock_driver.execute_query.call_count == 6
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_empty_triples(self, mock_graph_db):
"""Test handling message with no triples"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message with empty triples
mock_message = MagicMock()
mock_message.triples = []
await processor.store_triples(mock_message)
# Index creation during __init__ goes through session.run, not execute_query,
# so an empty message should leave execute_query completely untouched
mock_driver.execute_query.assert_not_called()
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://neo4j:7687'
assert hasattr(args, 'username')
assert args.username == 'neo4j'
assert hasattr(args, 'password')
assert args.password == 'password'
assert hasattr(args, 'database')
assert args.database == 'neo4j'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-host', 'bolt://custom:7687',
'--username', 'testuser',
'--password', 'testpass',
'--database', 'testdb'
])
assert args.graph_host == 'bolt://custom:7687'
assert args.username == 'testuser'
assert args.password == 'testpass'
assert args.database == 'testdb'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'bolt://short:7687'])
assert args.graph_host == 'bolt://short:7687'
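The default/custom/short-form trio of tests exercises standard argparse behaviour. A hypothetical parser mirroring the options these tests check (the real defaults live in the Processor's `add_args`; this sketch only assumes the flags the tests pass):

```python
from argparse import ArgumentParser

parser = ArgumentParser()
# "-g/--graph-host" stores under dest "graph_host" (dashes become underscores)
parser.add_argument("-g", "--graph-host", default="bolt://neo4j:7687")
parser.add_argument("--username", default="neo4j")
parser.add_argument("--password", default="password")
parser.add_argument("--database", default="neo4j")

args = parser.parse_args([])
assert args.graph_host == "bolt://neo4j:7687"  # defaults apply with no argv

args = parser.parse_args(["-g", "bolt://short:7687"])
assert args.graph_host == "bolt://short:7687"  # short form maps to same dest
```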
@patch('trustgraph.storage.triples.neo4j.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.triples.neo4j.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nGraph writer. Input is graph edge. Writes edges to Neo4j graph.\n"
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_special_characters(self, mock_graph_db):
"""Test handling triples with special characters and unicode"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Create triple with special characters
triple = MagicMock()
triple.s.value = "http://example.com/subject with spaces"
triple.p.value = "http://example.com/predicate:with/symbols"
triple.o.value = 'literal with "quotes" and unicode: ñáéíóú'
triple.o.is_uri = False
mock_message = MagicMock()
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify the triple was processed with special characters preserved
mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri})",
uri="http://example.com/subject with spaces",
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value})",
value='literal with "quotes" and unicode: ñáéíóú',
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
src="http://example.com/subject with spaces",
dest='literal with "quotes" and unicode: ñáéíóú',
uri="http://example.com/predicate:with/symbols",
database_="neo4j"
)