Increase storage test coverage (#435)

* Fixing storage and adding tests

* PR pipeline only runs quick tests
This commit is contained in:
cybermaggedon 2025-07-15 09:33:35 +01:00 committed by GitHub
parent 4daa54abaf
commit f37decea2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 7606 additions and 754 deletions

View file

@ -48,8 +48,8 @@ jobs:
- name: Unit tests
run: pytest tests/unit
- name: Integration tests
run: pytest tests/integration
- name: Integration tests (cut out the long-running tests)
run: pytest tests/integration -m 'not slow'
- name: Contract tests
run: pytest tests/contract

View file

@ -0,0 +1,112 @@
"""
Helper for managing Cassandra containers in integration tests
Alternative to testcontainers for Fedora/Podman compatibility
"""
import subprocess
import time
import socket
from contextlib import contextmanager
from cassandra.cluster import Cluster
from cassandra.policies import RetryPolicy
class CassandraTestContainer:
    """Simple Cassandra container manager using Podman.

    Starts a throwaway Cassandra container, blocks until it answers CQL
    queries, and removes the container on stop().  Lightweight alternative
    to testcontainers for Fedora/Podman compatibility.
    """

    def __init__(self, image="docker.io/library/cassandra:4.1", port=9042):
        self.image = image
        self.port = port
        # Unique per-run name so stale containers from earlier runs don't collide.
        self.container_name = f"test-cassandra-{int(time.time())}"
        self.container_id = None

    def start(self):
        """Start the Cassandra container and wait until it is ready.

        Returns:
            self, so callers can chain construction and startup.

        Raises:
            RuntimeError: if podman fails to start the container, or
                Cassandra never becomes ready within the wait timeout.
        """
        # Remove any existing container with the same name (best effort).
        subprocess.run(
            ["podman", "rm", "-f", self.container_name],
            capture_output=True,
        )
        # Start a new container; skipping the gossip-settle wait speeds startup.
        result = subprocess.run(
            [
                "podman", "run", "-d",
                "--name", self.container_name,
                "-p", f"{self.port}:9042",
                "-e", "JVM_OPTS=-Dcassandra.skip_wait_for_gossip_to_settle=0",
                self.image,
            ],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            raise RuntimeError(f"Failed to start container: {result.stderr}")
        self.container_id = result.stdout.strip()
        # Wait for Cassandra to be ready before handing the container back.
        self._wait_for_ready()
        return self

    def stop(self):
        """Stop and remove the container, ignoring removal errors."""
        if self.container_name:
            # Small delay before stopping to let client connections close.
            # (Uses the module-level `time` import; no local re-import needed.)
            time.sleep(0.5)
            subprocess.run(
                ["podman", "rm", "-f", self.container_name],
                capture_output=True,
            )

    def get_connection_host_port(self):
        """Return the (host, port) pair clients should connect to."""
        return "localhost", self.port

    def _wait_for_ready(self, timeout=120):
        """Poll until Cassandra answers a CQL query.

        Raises:
            RuntimeError: if Cassandra is not ready within `timeout` seconds.
        """
        start_time = time.time()
        print(f"Waiting for Cassandra to be ready on port {self.port}...")
        while time.time() - start_time < timeout:
            try:
                # Cheap first check: is the TCP port open at all?
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.settimeout(1)
                port_open = sock.connect_ex(("localhost", self.port)) == 0
                sock.close()
                if port_open:
                    # Port is open; verify with a real CQL round trip.
                    try:
                        cluster = Cluster(['localhost'], port=self.port)
                        cluster.connect_timeout = 5
                        session = cluster.connect()
                        # Simple query to verify Cassandra is actually serving.
                        session.execute("SELECT release_version FROM system.local")
                        session.shutdown()
                        cluster.shutdown()
                        print("Cassandra is ready!")
                        return
                    except Exception as e:
                        print(f"Cassandra not ready yet: {e}")
            except Exception as e:
                print(f"Connection check failed: {e}")
            time.sleep(3)
        raise RuntimeError(f"Cassandra not ready after {timeout} seconds")
@contextmanager
def cassandra_container(image="docker.io/library/cassandra:4.1", port=9042):
    """Context manager yielding a started Cassandra test container.

    The container is always torn down on exit, even if the body raises.
    """
    managed = CassandraTestContainer(image, port)
    try:
        # start() returns the container itself once Cassandra is ready.
        yield managed.start()
    finally:
        managed.stop()

View file

@ -383,4 +383,22 @@ def sample_kg_triples():
# Test markers for integration tests.
# A single module-level assignment suffices; the former duplicate line was
# redundant (the second assignment simply overwrote the first).
pytestmark = pytest.mark.integration
def pytest_sessionfinish(session, exitstatus):
    """
    Pytest hook invoked once the whole test run has finished, just before
    the exit status is returned.

    Gives the Cassandra driver's background threads time to wind down so
    pytest does not exit with "cannot schedule new futures after shutdown"
    errors.
    """
    import gc
    import time
    # Drop any lingering driver objects so their worker threads can exit.
    gc.collect()
    # Then grant the remaining threads a grace period to shut down cleanly.
    time.sleep(2)

View file

@ -0,0 +1,411 @@
"""
Cassandra integration tests using Podman containers
These tests verify end-to-end functionality of Cassandra storage and query processors
with real database instances. Compatible with Fedora Linux and Podman.
Uses a single container for all tests to minimize startup time.
"""
import pytest
import asyncio
import time
from unittest.mock import MagicMock
from .cassandra_test_helper import cassandra_container
from trustgraph.direct.cassandra import TrustGraph
from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor
from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor
from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest
@pytest.mark.integration
@pytest.mark.slow
class TestCassandraIntegration:
    """Integration tests for Cassandra using a single shared container"""
    # NOTE: one class-scoped container is reused by every test in this class
    # so Cassandra's slow startup cost is paid only once per test session.
    @pytest.fixture(scope="class")
    def cassandra_shared_container(self):
        """Class-level fixture: single Cassandra container for all tests"""
        with cassandra_container() as container:
            yield container
    def setup_method(self):
        """Track all created clients for cleanup"""
        # Every TrustGraph client opened during a test is appended here so
        # teardown_method can close them all afterwards.
        self.clients_to_close = []
    def teardown_method(self):
        """Clean up all Cassandra connections"""
        import gc
        for client in self.clients_to_close:
            try:
                client.close()
            except Exception:
                pass  # Ignore errors during cleanup
        # Clear the list and force garbage collection
        self.clients_to_close.clear()
        gc.collect()
        # Small delay to let driver threads finish
        time.sleep(0.5)
    @pytest.mark.asyncio
    async def test_complete_cassandra_integration(self, cassandra_shared_container):
        """Complete integration test covering all Cassandra functionality.

        Runs five ordered phases against the shared container (basic client
        operations, storage processor, query processor, concurrent writes,
        complex graph queries).  Each phase uses its own keyspace so the
        phases do not interfere with one another.
        """
        container = cassandra_shared_container
        host, port = container.get_connection_host_port()
        print("=" * 60)
        print("RUNNING COMPLETE CASSANDRA INTEGRATION TEST")
        print("=" * 60)
        # =====================================================
        # Test 1: Basic TrustGraph Operations
        # =====================================================
        print("\n1. Testing basic TrustGraph operations...")
        client = TrustGraph(
            hosts=[host],
            keyspace="test_basic",
            table="test_table"
        )
        self.clients_to_close.append(client)
        # Insert test data
        client.insert("http://example.org/alice", "knows", "http://example.org/bob")
        client.insert("http://example.org/alice", "age", "25")
        client.insert("http://example.org/bob", "age", "30")
        # Test get_all
        all_results = list(client.get_all(limit=10))
        assert len(all_results) == 3
        print(f"✓ Stored and retrieved {len(all_results)} triples")
        # Test get_s (subject query)
        alice_results = list(client.get_s("http://example.org/alice", limit=10))
        assert len(alice_results) == 2
        alice_predicates = [r.p for r in alice_results]
        assert "knows" in alice_predicates
        assert "age" in alice_predicates
        print("✓ Subject queries working")
        # Test get_p (predicate query)
        age_results = list(client.get_p("age", limit=10))
        assert len(age_results) == 2
        age_subjects = [r.s for r in age_results]
        assert "http://example.org/alice" in age_subjects
        assert "http://example.org/bob" in age_subjects
        print("✓ Predicate queries working")
        # =====================================================
        # Test 2: Storage Processor Integration
        # =====================================================
        print("\n2. Testing storage processor integration...")
        storage_processor = StorageProcessor(
            taskgroup=MagicMock(),
            hosts=[host],
            keyspace="test_storage",
            table="test_triples"
        )
        # Track the TrustGraph instance that will be created
        self.storage_processor = storage_processor
        # Create test message
        storage_message = Triples(
            metadata=Metadata(user="testuser", collection="testcol"),
            triples=[
                Triple(
                    s=Value(value="http://example.org/person1", is_uri=True),
                    p=Value(value="http://example.org/name", is_uri=True),
                    o=Value(value="Alice Smith", is_uri=False)
                ),
                Triple(
                    s=Value(value="http://example.org/person1", is_uri=True),
                    p=Value(value="http://example.org/age", is_uri=True),
                    o=Value(value="25", is_uri=False)
                ),
                Triple(
                    s=Value(value="http://example.org/person1", is_uri=True),
                    p=Value(value="http://example.org/department", is_uri=True),
                    o=Value(value="Engineering", is_uri=False)
                )
            ]
        )
        # Store triples via processor
        await storage_processor.store_triples(storage_message)
        # Track the created TrustGraph instance so teardown can close it.
        # NOTE(review): presumably store_triples lazily creates `tg` on first
        # use — verify against the processor implementation.
        if hasattr(storage_processor, 'tg'):
            self.clients_to_close.append(storage_processor.tg)
        # Verify data was stored
        storage_results = list(storage_processor.tg.get_s("http://example.org/person1", limit=10))
        assert len(storage_results) == 3
        predicates = [row.p for row in storage_results]
        objects = [row.o for row in storage_results]
        assert "http://example.org/name" in predicates
        assert "http://example.org/age" in predicates
        assert "http://example.org/department" in predicates
        assert "Alice Smith" in objects
        assert "25" in objects
        assert "Engineering" in objects
        print("✓ Storage processor working")
        # =====================================================
        # Test 3: Query Processor Integration
        # =====================================================
        print("\n3. Testing query processor integration...")
        query_processor = QueryProcessor(
            taskgroup=MagicMock(),
            hosts=[host],
            keyspace="test_query",
            table="test_triples"
        )
        # Use same storage processor for the query keyspace
        query_storage_processor = StorageProcessor(
            taskgroup=MagicMock(),
            hosts=[host],
            keyspace="test_query",
            table="test_triples"
        )
        # Store test data for querying
        query_test_message = Triples(
            metadata=Metadata(user="testuser", collection="testcol"),
            triples=[
                Triple(
                    s=Value(value="http://example.org/alice", is_uri=True),
                    p=Value(value="http://example.org/knows", is_uri=True),
                    o=Value(value="http://example.org/bob", is_uri=True)
                ),
                Triple(
                    s=Value(value="http://example.org/alice", is_uri=True),
                    p=Value(value="http://example.org/age", is_uri=True),
                    o=Value(value="30", is_uri=False)
                ),
                Triple(
                    s=Value(value="http://example.org/bob", is_uri=True),
                    p=Value(value="http://example.org/knows", is_uri=True),
                    o=Value(value="http://example.org/charlie", is_uri=True)
                )
            ]
        )
        await query_storage_processor.store_triples(query_test_message)
        # Debug: Check what was actually stored
        print("Debug: Checking what was stored for Alice...")
        direct_results = list(query_storage_processor.tg.get_s("http://example.org/alice", limit=10))
        print(f"Direct TrustGraph results: {len(direct_results)}")
        for result in direct_results:
            print(f"  S=http://example.org/alice, P={result.p}, O={result.o}")
        # Test S query (find all relationships for Alice)
        s_query = TriplesQueryRequest(
            s=Value(value="http://example.org/alice", is_uri=True),
            p=None,  # None for wildcard
            o=None,  # None for wildcard
            limit=10,
            user="testuser",
            collection="testcol"
        )
        s_results = await query_processor.query_triples(s_query)
        print(f"Query processor results: {len(s_results)}")
        for result in s_results:
            print(f"  S={result.s.value}, P={result.p.value}, O={result.o.value}")
        assert len(s_results) == 2
        s_predicates = [t.p.value for t in s_results]
        assert "http://example.org/knows" in s_predicates
        assert "http://example.org/age" in s_predicates
        print("✓ Subject queries via processor working")
        # Test P query (find all "knows" relationships)
        p_query = TriplesQueryRequest(
            s=None,  # None for wildcard
            p=Value(value="http://example.org/knows", is_uri=True),
            o=None,  # None for wildcard
            limit=10,
            user="testuser",
            collection="testcol"
        )
        p_results = await query_processor.query_triples(p_query)
        print(p_results)
        assert len(p_results) == 2  # Alice knows Bob, Bob knows Charlie
        p_subjects = [t.s.value for t in p_results]
        assert "http://example.org/alice" in p_subjects
        assert "http://example.org/bob" in p_subjects
        print("✓ Predicate queries via processor working")
        # =====================================================
        # Test 4: Concurrent Operations
        # =====================================================
        print("\n4. Testing concurrent operations...")
        concurrent_processor = StorageProcessor(
            taskgroup=MagicMock(),
            hosts=[host],
            keyspace="test_concurrent",
            table="test_triples"
        )
        # Create multiple coroutines for concurrent storage
        async def store_person_data(person_id, name, age, department):
            # One message per person: name, age, department triples.
            message = Triples(
                metadata=Metadata(user="concurrent_test", collection="people"),
                triples=[
                    Triple(
                        s=Value(value=f"http://example.org/{person_id}", is_uri=True),
                        p=Value(value="http://example.org/name", is_uri=True),
                        o=Value(value=name, is_uri=False)
                    ),
                    Triple(
                        s=Value(value=f"http://example.org/{person_id}", is_uri=True),
                        p=Value(value="http://example.org/age", is_uri=True),
                        o=Value(value=str(age), is_uri=False)
                    ),
                    Triple(
                        s=Value(value=f"http://example.org/{person_id}", is_uri=True),
                        p=Value(value="http://example.org/department", is_uri=True),
                        o=Value(value=department, is_uri=False)
                    )
                ]
            )
            await concurrent_processor.store_triples(message)
        # Store data for multiple people concurrently
        people_data = [
            ("person1", "John Doe", 25, "Engineering"),
            ("person2", "Jane Smith", 30, "Marketing"),
            ("person3", "Bob Wilson", 35, "Engineering"),
            ("person4", "Alice Brown", 28, "Sales"),
        ]
        # Run storage operations concurrently
        store_tasks = [store_person_data(pid, name, age, dept) for pid, name, age, dept in people_data]
        await asyncio.gather(*store_tasks)
        # Track the created TrustGraph instance
        if hasattr(concurrent_processor, 'tg'):
            self.clients_to_close.append(concurrent_processor.tg)
        # Verify all names were stored
        name_results = list(concurrent_processor.tg.get_p("http://example.org/name", limit=10))
        assert len(name_results) == 4
        stored_names = [r.o for r in name_results]
        expected_names = ["John Doe", "Jane Smith", "Bob Wilson", "Alice Brown"]
        for name in expected_names:
            assert name in stored_names
        # Verify department data
        dept_results = list(concurrent_processor.tg.get_p("http://example.org/department", limit=10))
        assert len(dept_results) == 4
        stored_depts = [r.o for r in dept_results]
        assert "Engineering" in stored_depts
        assert "Marketing" in stored_depts
        assert "Sales" in stored_depts
        print("✓ Concurrent operations working")
        # =====================================================
        # Test 5: Complex Queries and Data Integrity
        # =====================================================
        print("\n5. Testing complex queries and data integrity...")
        complex_processor = StorageProcessor(
            taskgroup=MagicMock(),
            hosts=[host],
            keyspace="test_complex",
            table="test_triples"
        )
        # Create a knowledge graph about a company
        company_graph = Triples(
            metadata=Metadata(user="integration_test", collection="company"),
            triples=[
                # People and their types
                Triple(
                    s=Value(value="http://company.org/alice", is_uri=True),
                    p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
                    o=Value(value="http://company.org/Employee", is_uri=True)
                ),
                Triple(
                    s=Value(value="http://company.org/bob", is_uri=True),
                    p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
                    o=Value(value="http://company.org/Employee", is_uri=True)
                ),
                # Relationships
                Triple(
                    s=Value(value="http://company.org/alice", is_uri=True),
                    p=Value(value="http://company.org/reportsTo", is_uri=True),
                    o=Value(value="http://company.org/bob", is_uri=True)
                ),
                Triple(
                    s=Value(value="http://company.org/alice", is_uri=True),
                    p=Value(value="http://company.org/worksIn", is_uri=True),
                    o=Value(value="http://company.org/engineering", is_uri=True)
                ),
                # Personal info
                Triple(
                    s=Value(value="http://company.org/alice", is_uri=True),
                    p=Value(value="http://company.org/fullName", is_uri=True),
                    o=Value(value="Alice Johnson", is_uri=False)
                ),
                Triple(
                    s=Value(value="http://company.org/alice", is_uri=True),
                    p=Value(value="http://company.org/email", is_uri=True),
                    o=Value(value="alice@company.org", is_uri=False)
                ),
            ]
        )
        # Store the company knowledge graph
        await complex_processor.store_triples(company_graph)
        # Track the created TrustGraph instance
        if hasattr(complex_processor, 'tg'):
            self.clients_to_close.append(complex_processor.tg)
        # Verify all Alice's data
        alice_data = list(complex_processor.tg.get_s("http://company.org/alice", limit=20))
        assert len(alice_data) == 5
        alice_predicates = [r.p for r in alice_data]
        expected_predicates = [
            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
            "http://company.org/reportsTo",
            "http://company.org/worksIn",
            "http://company.org/fullName",
            "http://company.org/email"
        ]
        for pred in expected_predicates:
            assert pred in alice_predicates
        # Test type-based queries
        employee_results = list(complex_processor.tg.get_p("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", limit=10))
        print(employee_results)
        assert len(employee_results) == 2
        employees = [r.s for r in employee_results]
        assert "http://company.org/alice" in employees
        assert "http://company.org/bob" in employees
        print("✓ Complex queries and data integrity working")
        # =====================================================
        # Summary
        # =====================================================
        print("\n" + "=" * 60)
        print("✅ ALL CASSANDRA INTEGRATION TESTS PASSED!")
        print("✅ Basic operations: PASSED")
        print("✅ Storage processor: PASSED")
        print("✅ Query processor: PASSED")
        print("✅ Concurrent operations: PASSED")
        print("✅ Complex queries: PASSED")
        print("=" * 60)

View file

@ -0,0 +1,456 @@
"""
Tests for Milvus document embeddings query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.doc_embeddings.milvus.service import Processor
from trustgraph.schema import DocumentEmbeddingsRequest
class TestMilvusDocEmbeddingsQueryProcessor:
    """Test cases for Milvus document embeddings query processor"""
    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        # Patch DocVectors so constructing the Processor never opens a
        # real Milvus connection; the mock store is reachable as
        # processor.vecstore in the tests below.
        with patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors') as mock_doc_vectors:
            mock_vecstore = MagicMock()
            mock_doc_vectors.return_value = mock_vecstore
            processor = Processor(
                taskgroup=MagicMock(),
                id='test-milvus-de-query',
                store_uri='http://localhost:19530'
            )
            return processor
    @pytest.fixture
    def mock_query_request(self):
        """Create a mock query request for testing"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=10
        )
        return query
    @patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors')
    def test_processor_initialization_with_defaults(self, mock_doc_vectors):
        """Test processor initialization with default parameters"""
        taskgroup_mock = MagicMock()
        mock_vecstore = MagicMock()
        mock_doc_vectors.return_value = mock_vecstore
        processor = Processor(taskgroup=taskgroup_mock)
        # Default store URI is expected when none is given.
        mock_doc_vectors.assert_called_once_with('http://localhost:19530')
        assert processor.vecstore == mock_vecstore
    @patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors')
    def test_processor_initialization_with_custom_params(self, mock_doc_vectors):
        """Test processor initialization with custom parameters"""
        taskgroup_mock = MagicMock()
        mock_vecstore = MagicMock()
        mock_doc_vectors.return_value = mock_vecstore
        processor = Processor(
            taskgroup=taskgroup_mock,
            store_uri='http://custom-milvus:19530'
        )
        mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530')
        assert processor.vecstore == mock_vecstore
    @pytest.mark.asyncio
    async def test_query_document_embeddings_single_vector(self, processor):
        """Test querying document embeddings with a single vector"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock search results in Milvus hit format: entity -> doc text.
        mock_results = [
            {"entity": {"doc": "First document chunk"}},
            {"entity": {"doc": "Second document chunk"}},
            {"entity": {"doc": "Third document chunk"}},
        ]
        processor.vecstore.search.return_value = mock_results
        result = await processor.query_document_embeddings(query)
        # Verify search was called with correct parameters
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)
        # Verify results are document chunks
        assert len(result) == 3
        assert result[0] == "First document chunk"
        assert result[1] == "Second document chunk"
        assert result[2] == "Third document chunk"
    @pytest.mark.asyncio
    async def test_query_document_embeddings_multiple_vectors(self, processor):
        """Test querying document embeddings with multiple vectors"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=3
        )
        # Mock search results - different results for each vector
        mock_results_1 = [
            {"entity": {"doc": "Document from first vector"}},
            {"entity": {"doc": "Another doc from first vector"}},
        ]
        mock_results_2 = [
            {"entity": {"doc": "Document from second vector"}},
        ]
        processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
        result = await processor.query_document_embeddings(query)
        # Verify search was called twice with correct parameters
        expected_calls = [
            (([0.1, 0.2, 0.3],), {"limit": 3}),
            (([0.4, 0.5, 0.6],), {"limit": 3}),
        ]
        assert processor.vecstore.search.call_count == 2
        for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
            actual_call = processor.vecstore.search.call_args_list[i]
            assert actual_call[0] == expected_args
            assert actual_call[1] == expected_kwargs
        # Verify results from all vectors are combined
        assert len(result) == 3
        assert "Document from first vector" in result
        assert "Another doc from first vector" in result
        assert "Document from second vector" in result
    @pytest.mark.asyncio
    async def test_query_document_embeddings_with_limit(self, processor):
        """Test querying document embeddings respects limit parameter"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=2
        )
        # Mock search results - more results than limit
        mock_results = [
            {"entity": {"doc": "Document 1"}},
            {"entity": {"doc": "Document 2"}},
            {"entity": {"doc": "Document 3"}},
            {"entity": {"doc": "Document 4"}},
        ]
        processor.vecstore.search.return_value = mock_results
        result = await processor.query_document_embeddings(query)
        # Verify search was called with the specified limit
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=2)
        # Verify all results are returned (Milvus handles limit internally,
        # so the processor does not re-truncate what the store returns)
        assert len(result) == 4
    @pytest.mark.asyncio
    async def test_query_document_embeddings_empty_vectors(self, processor):
        """Test querying document embeddings with empty vectors list"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[],
            limit=5
        )
        result = await processor.query_document_embeddings(query)
        # Verify no search was called
        processor.vecstore.search.assert_not_called()
        # Verify empty results
        assert len(result) == 0
    @pytest.mark.asyncio
    async def test_query_document_embeddings_empty_search_results(self, processor):
        """Test querying document embeddings with empty search results"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock empty search results
        processor.vecstore.search.return_value = []
        result = await processor.query_document_embeddings(query)
        # Verify search was called
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)
        # Verify empty results
        assert len(result) == 0
    @pytest.mark.asyncio
    async def test_query_document_embeddings_unicode_documents(self, processor):
        """Test querying document embeddings with Unicode document content"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock search results with Unicode content
        mock_results = [
            {"entity": {"doc": "Document with Unicode: éñ中文🚀"}},
            {"entity": {"doc": "Regular ASCII document"}},
            {"entity": {"doc": "Document with émojis: 😀🎉"}},
        ]
        processor.vecstore.search.return_value = mock_results
        result = await processor.query_document_embeddings(query)
        # Verify Unicode content is preserved
        assert len(result) == 3
        assert "Document with Unicode: éñ中文🚀" in result
        assert "Regular ASCII document" in result
        assert "Document with émojis: 😀🎉" in result
    @pytest.mark.asyncio
    async def test_query_document_embeddings_large_documents(self, processor):
        """Test querying document embeddings with large document content"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock search results with large content
        large_doc = "A" * 10000  # 10KB of content
        mock_results = [
            {"entity": {"doc": large_doc}},
            {"entity": {"doc": "Small document"}},
        ]
        processor.vecstore.search.return_value = mock_results
        result = await processor.query_document_embeddings(query)
        # Verify large content is preserved
        assert len(result) == 2
        assert large_doc in result
        assert "Small document" in result
    @pytest.mark.asyncio
    async def test_query_document_embeddings_special_characters(self, processor):
        """Test querying document embeddings with special characters in documents"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock search results with special characters
        mock_results = [
            {"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}},
            {"entity": {"doc": "Document with\nnewlines\tand\ttabs"}},
            {"entity": {"doc": "Document with special chars: @#$%^&*()"}},
        ]
        processor.vecstore.search.return_value = mock_results
        result = await processor.query_document_embeddings(query)
        # Verify special characters are preserved
        assert len(result) == 3
        assert "Document with \"quotes\" and 'apostrophes'" in result
        assert "Document with\nnewlines\tand\ttabs" in result
        assert "Document with special chars: @#$%^&*()" in result
    @pytest.mark.asyncio
    async def test_query_document_embeddings_zero_limit(self, processor):
        """Test querying document embeddings with zero limit"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=0
        )
        result = await processor.query_document_embeddings(query)
        # Verify no search was called (optimization for zero limit)
        processor.vecstore.search.assert_not_called()
        # Verify empty results due to zero limit
        assert len(result) == 0
    @pytest.mark.asyncio
    async def test_query_document_embeddings_negative_limit(self, processor):
        """Test querying document embeddings with negative limit"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=-1
        )
        result = await processor.query_document_embeddings(query)
        # Verify no search was called (optimization for negative limit)
        processor.vecstore.search.assert_not_called()
        # Verify empty results due to negative limit
        assert len(result) == 0
    @pytest.mark.asyncio
    async def test_query_document_embeddings_exception_handling(self, processor):
        """Test exception handling during query processing"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock search to raise exception
        processor.vecstore.search.side_effect = Exception("Milvus connection failed")
        # Should raise the exception (the processor does not swallow it)
        with pytest.raises(Exception, match="Milvus connection failed"):
            await processor.query_document_embeddings(query)
    @pytest.mark.asyncio
    async def test_query_document_embeddings_different_vector_dimensions(self, processor):
        """Test querying document embeddings with different vector dimensions"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[
                [0.1, 0.2],  # 2D vector
                [0.3, 0.4, 0.5, 0.6],  # 4D vector
                [0.7, 0.8, 0.9]  # 3D vector
            ],
            limit=5
        )
        # Mock search results for each vector
        mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}]
        mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}]
        mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}]
        processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
        result = await processor.query_document_embeddings(query)
        # Verify all vectors were searched
        assert processor.vecstore.search.call_count == 3
        # Verify results from all dimensions
        assert len(result) == 3
        assert "Document from 2D vector" in result
        assert "Document from 4D vector" in result
        assert "Document from 3D vector" in result
    @pytest.mark.asyncio
    async def test_query_document_embeddings_duplicate_documents(self, processor):
        """Test querying document embeddings with duplicate documents in results"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=5
        )
        # Mock search results with duplicates across vectors
        mock_results_1 = [
            {"entity": {"doc": "Document A"}},
            {"entity": {"doc": "Document B"}},
        ]
        mock_results_2 = [
            {"entity": {"doc": "Document B"}},  # Duplicate
            {"entity": {"doc": "Document C"}},
        ]
        processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
        result = await processor.query_document_embeddings(query)
        # Note: Unlike graph embeddings, doc embeddings don't deduplicate.
        # This preserves ranking and allows multiple occurrences.
        assert len(result) == 4
        assert result.count("Document B") == 2  # Should appear twice
        assert "Document A" in result
        assert "Document C" in result
    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        # Mock the parent class add_args method so only this class's
        # arguments are registered on the parser.
        with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)
            # Verify parent add_args was called
            mock_parent_add_args.assert_called_once()
            # Verify our specific arguments were added
            # Parse empty args to check defaults
            args = parser.parse_args([])
            assert hasattr(args, 'store_uri')
            assert args.store_uri == 'http://localhost:19530'
    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)
            # Test parsing with custom values
            args = parser.parse_args([
                '--store-uri', 'http://custom-milvus:19530'
            ])
            assert args.store_uri == 'http://custom-milvus:19530'
    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)
            # Test parsing with short form (-t is the short flag for --store-uri)
            args = parser.parse_args(['-t', 'http://short-milvus:19530'])
            assert args.store_uri == 'http://short-milvus:19530'
    @patch('trustgraph.query.doc_embeddings.milvus.service.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.query.doc_embeddings.milvus.service import run, default_ident
        run()
        mock_launch.assert_called_once_with(
            default_ident,
            "\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n"
        )

View file

@ -0,0 +1,558 @@
"""
Tests for Pinecone document embeddings query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.doc_embeddings.pinecone.service import Processor
class TestPineconeDocEmbeddingsQueryProcessor:
    """Test cases for Pinecone document embeddings query processor.

    All tests patch the Pinecone client classes so no network access
    occurs; the processor's query logic is exercised against MagicMock
    indexes and results.
    """

    @pytest.fixture
    def mock_query_message(self):
        """Create a mock query message for testing"""
        message = MagicMock()
        # Two 3-dimensional query vectors; the processor issues one
        # Pinecone query per vector.
        message.vectors = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6]
        ]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        return message

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        # Patch the Pinecone client class so construction does not reach
        # out to the real service.
        with patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone') as mock_pinecone_class:
            mock_pinecone = MagicMock()
            mock_pinecone_class.return_value = mock_pinecone
            processor = Processor(
                taskgroup=MagicMock(),
                id='test-pinecone-de-query',
                api_key='test-api-key'
            )
            return processor

    @patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone')
    @patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'env-api-key')
    def test_processor_initialization_with_defaults(self, mock_pinecone_class):
        """Test processor initialization with default parameters"""
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(taskgroup=taskgroup_mock)
        # With no explicit key, the module-level default (patched above)
        # is used to construct the client.
        mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
        assert processor.pinecone == mock_pinecone
        assert processor.api_key == 'env-api-key'

    @patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone')
    def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
        """Test processor initialization with custom parameters"""
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='custom-api-key'
        )
        mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
        assert processor.api_key == 'custom-api-key'

    @patch('trustgraph.query.doc_embeddings.pinecone.service.PineconeGRPC')
    def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
        """Test processor initialization with custom URL (GRPC mode)"""
        mock_pinecone = MagicMock()
        mock_pinecone_grpc_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='test-api-key',
            url='https://custom-host.pinecone.io'
        )
        # Supplying a URL switches construction to the GRPC client with
        # an explicit host.
        mock_pinecone_grpc_class.assert_called_once_with(
            api_key='test-api-key',
            host='https://custom-host.pinecone.io'
        )
        assert processor.pinecone == mock_pinecone
        assert processor.url == 'https://custom-host.pinecone.io'

    @patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'not-specified')
    def test_processor_initialization_missing_api_key(self):
        """Test processor initialization fails with missing API key"""
        taskgroup_mock = MagicMock()
        # 'not-specified' is the sentinel default; the processor must
        # refuse to start without a real key.
        with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
            Processor(taskgroup=taskgroup_mock)

    @pytest.mark.asyncio
    async def test_query_document_embeddings_single_vector(self, processor):
        """Test querying document embeddings with a single vector"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 3
        message.user = 'test_user'
        message.collection = 'test_collection'
        # Mock index and query results
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': 'First document chunk'}),
            MagicMock(metadata={'doc': 'Second document chunk'}),
            MagicMock(metadata={'doc': 'Third document chunk'})
        ]
        mock_index.query.return_value = mock_results
        chunks = await processor.query_document_embeddings(message)
        # Verify index was accessed correctly
        # Index name format: d-<user>-<collection>-<vector dimension>
        expected_index_name = "d-test_user-test_collection-3"
        processor.pinecone.Index.assert_called_once_with(expected_index_name)
        # Verify query parameters
        mock_index.query.assert_called_once_with(
            vector=[0.1, 0.2, 0.3],
            top_k=3,
            include_values=False,
            include_metadata=True
        )
        # Verify results
        assert len(chunks) == 3
        assert chunks[0] == 'First document chunk'
        assert chunks[1] == 'Second document chunk'
        assert chunks[2] == 'Third document chunk'

    @pytest.mark.asyncio
    async def test_query_document_embeddings_multiple_vectors(self, processor, mock_query_message):
        """Test querying document embeddings with multiple vectors"""
        # Mock index and query results
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        # First query results
        mock_results1 = MagicMock()
        mock_results1.matches = [
            MagicMock(metadata={'doc': 'Document chunk 1'}),
            MagicMock(metadata={'doc': 'Document chunk 2'})
        ]
        # Second query results
        mock_results2 = MagicMock()
        mock_results2.matches = [
            MagicMock(metadata={'doc': 'Document chunk 3'}),
            MagicMock(metadata={'doc': 'Document chunk 4'})
        ]
        mock_index.query.side_effect = [mock_results1, mock_results2]
        chunks = await processor.query_document_embeddings(mock_query_message)
        # Verify both queries were made
        assert mock_index.query.call_count == 2
        # Verify results from both queries
        assert len(chunks) == 4
        assert 'Document chunk 1' in chunks
        assert 'Document chunk 2' in chunks
        assert 'Document chunk 3' in chunks
        assert 'Document chunk 4' in chunks

    @pytest.mark.asyncio
    async def test_query_document_embeddings_limit_handling(self, processor):
        """Test that query respects the limit parameter"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 2
        message.user = 'test_user'
        message.collection = 'test_collection'
        # Mock index with many results
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': f'Document chunk {i}'}) for i in range(10)
        ]
        mock_index.query.return_value = mock_results
        chunks = await processor.query_document_embeddings(message)
        # Verify limit is passed to query
        mock_index.query.assert_called_once()
        call_args = mock_index.query.call_args
        assert call_args[1]['top_k'] == 2
        # Results should contain all returned chunks (limit is applied by Pinecone)
        assert len(chunks) == 10

    @pytest.mark.asyncio
    async def test_query_document_embeddings_zero_limit(self, processor):
        """Test querying with zero limit returns empty results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 0
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        chunks = await processor.query_document_embeddings(message)
        # Verify no query was made and empty result returned
        mock_index.query.assert_not_called()
        assert chunks == []

    @pytest.mark.asyncio
    async def test_query_document_embeddings_negative_limit(self, processor):
        """Test querying with negative limit returns empty results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = -1
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        chunks = await processor.query_document_embeddings(message)
        # Verify no query was made and empty result returned
        mock_index.query.assert_not_called()
        assert chunks == []

    @pytest.mark.asyncio
    async def test_query_document_embeddings_different_vector_dimensions(self, processor):
        """Test querying with vectors of different dimensions"""
        message = MagicMock()
        message.vectors = [
            [0.1, 0.2],  # 2D vector
            [0.3, 0.4, 0.5, 0.6]  # 4D vector
        ]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index_2d = MagicMock()
        mock_index_4d = MagicMock()
        # Route Index() lookups by the dimension suffix embedded in the
        # index name ("...-2" vs "...-4").
        def mock_index_side_effect(name):
            if name.endswith("-2"):
                return mock_index_2d
            elif name.endswith("-4"):
                return mock_index_4d
        processor.pinecone.Index.side_effect = mock_index_side_effect
        # Mock results for different dimensions
        mock_results_2d = MagicMock()
        mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D index'})]
        mock_index_2d.query.return_value = mock_results_2d
        mock_results_4d = MagicMock()
        mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D index'})]
        mock_index_4d.query.return_value = mock_results_4d
        chunks = await processor.query_document_embeddings(message)
        # Verify different indexes were used
        assert processor.pinecone.Index.call_count == 2
        mock_index_2d.query.assert_called_once()
        mock_index_4d.query.assert_called_once()
        # Verify results from both dimensions
        assert 'Document from 2D index' in chunks
        assert 'Document from 4D index' in chunks

    @pytest.mark.asyncio
    async def test_query_document_embeddings_empty_vectors_list(self, processor):
        """Test querying with empty vectors list"""
        message = MagicMock()
        message.vectors = []
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        chunks = await processor.query_document_embeddings(message)
        # Verify no queries were made and empty result returned
        processor.pinecone.Index.assert_not_called()
        mock_index.query.assert_not_called()
        assert chunks == []

    @pytest.mark.asyncio
    async def test_query_document_embeddings_no_results(self, processor):
        """Test querying when index returns no results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_results = MagicMock()
        mock_results.matches = []
        mock_index.query.return_value = mock_results
        chunks = await processor.query_document_embeddings(message)
        # Verify empty results
        assert chunks == []

    @pytest.mark.asyncio
    async def test_query_document_embeddings_unicode_content(self, processor):
        """Test querying document embeddings with Unicode content results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 2
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': 'Document with Unicode: éñ中文🚀'}),
            MagicMock(metadata={'doc': 'Regular ASCII document'})
        ]
        mock_index.query.return_value = mock_results
        chunks = await processor.query_document_embeddings(message)
        # Verify Unicode content is properly handled
        assert len(chunks) == 2
        assert 'Document with Unicode: éñ中文🚀' in chunks
        assert 'Regular ASCII document' in chunks

    @pytest.mark.asyncio
    async def test_query_document_embeddings_large_content(self, processor):
        """Test querying document embeddings with large content results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 1
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        # Create a large document content
        large_content = "A" * 10000  # 10KB of content
        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': large_content})
        ]
        mock_index.query.return_value = mock_results
        chunks = await processor.query_document_embeddings(message)
        # Verify large content is properly handled
        assert len(chunks) == 1
        assert chunks[0] == large_content

    @pytest.mark.asyncio
    async def test_query_document_embeddings_mixed_content_types(self, processor):
        """Test querying document embeddings with mixed content types"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': 'Short text'}),
            MagicMock(metadata={'doc': 'A' * 1000}),  # Long text
            MagicMock(metadata={'doc': 'Text with numbers: 123 and symbols: @#$'}),
            MagicMock(metadata={'doc': '  Whitespace text  '}),
            MagicMock(metadata={'doc': ''})  # Empty string
        ]
        mock_index.query.return_value = mock_results
        chunks = await processor.query_document_embeddings(message)
        # Verify all content types are properly handled (including the
        # empty string, which must not be dropped).
        assert len(chunks) == 5
        assert 'Short text' in chunks
        assert 'A' * 1000 in chunks
        assert 'Text with numbers: 123 and symbols: @#$' in chunks
        assert '  Whitespace text  ' in chunks
        assert '' in chunks

    @pytest.mark.asyncio
    async def test_query_document_embeddings_exception_handling(self, processor):
        """Test that exceptions are properly raised"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_index.query.side_effect = Exception("Query failed")
        # Query failures should propagate rather than be swallowed.
        with pytest.raises(Exception, match="Query failed"):
            await processor.query_document_embeddings(message)

    @pytest.mark.asyncio
    async def test_query_document_embeddings_index_access_failure(self, processor):
        """Test handling of index access failure"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        processor.pinecone.Index.side_effect = Exception("Index access failed")
        with pytest.raises(Exception, match="Index access failed"):
            await processor.query_document_embeddings(message)

    @pytest.mark.asyncio
    async def test_query_document_embeddings_vector_accumulation(self, processor):
        """Test that results from multiple vectors are properly accumulated"""
        message = MagicMock()
        message.vectors = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9]
        ]
        message.limit = 2
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        # Each query returns different results
        mock_results1 = MagicMock()
        mock_results1.matches = [
            MagicMock(metadata={'doc': 'Doc from vector 1.1'}),
            MagicMock(metadata={'doc': 'Doc from vector 1.2'})
        ]
        mock_results2 = MagicMock()
        mock_results2.matches = [
            MagicMock(metadata={'doc': 'Doc from vector 2.1'})
        ]
        mock_results3 = MagicMock()
        mock_results3.matches = [
            MagicMock(metadata={'doc': 'Doc from vector 3.1'}),
            MagicMock(metadata={'doc': 'Doc from vector 3.2'}),
            MagicMock(metadata={'doc': 'Doc from vector 3.3'})
        ]
        mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]
        chunks = await processor.query_document_embeddings(message)
        # Verify all queries were made
        assert mock_index.query.call_count == 3
        # Verify all results are accumulated
        assert len(chunks) == 6
        assert 'Doc from vector 1.1' in chunks
        assert 'Doc from vector 1.2' in chunks
        assert 'Doc from vector 2.1' in chunks
        assert 'Doc from vector 3.1' in chunks
        assert 'Doc from vector 3.2' in chunks
        assert 'Doc from vector 3.3' in chunks

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        # Mock the parent class add_args method
        with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)
        # Verify parent add_args was called
        mock_parent_add_args.assert_called_once()
        # Verify our specific arguments were added
        args = parser.parse_args([])
        assert hasattr(args, 'api_key')
        assert args.api_key == 'not-specified'  # Default value when no env var
        assert hasattr(args, 'url')
        assert args.url is None

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)
        # Test parsing with custom values
        args = parser.parse_args([
            '--api-key', 'custom-api-key',
            '--url', 'https://custom-host.pinecone.io'
        ])
        assert args.api_key == 'custom-api-key'
        assert args.url == 'https://custom-host.pinecone.io'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)
        # Test parsing with short form
        args = parser.parse_args([
            '-a', 'short-api-key',
            '-u', 'https://short-host.pinecone.io'
        ])
        assert args.api_key == 'short-api-key'
        assert args.url == 'https://short-host.pinecone.io'

    @patch('trustgraph.query.doc_embeddings.pinecone.service.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.query.doc_embeddings.pinecone.service import run, default_ident
        run()
        # run() should pass the module's ident and docstring through to launch.
        mock_launch.assert_called_once_with(
            default_ident,
            "\nDocument embeddings query service. Input is vector, output is an array\nof chunks. Pinecone implementation.\n"
        )

View file

@ -0,0 +1,484 @@
"""
Tests for Milvus graph embeddings query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.milvus.service import Processor
from trustgraph.schema import Value, GraphEmbeddingsRequest
class TestMilvusGraphEmbeddingsQueryProcessor:
    """Test cases for Milvus graph embeddings query processor.

    The EntityVectors store is patched throughout so no Milvus instance
    is required; tests verify search parameters, URI/literal Value
    conversion, deduplication and limit handling.
    """

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        with patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors') as mock_entity_vectors:
            mock_vecstore = MagicMock()
            mock_entity_vectors.return_value = mock_vecstore
            processor = Processor(
                taskgroup=MagicMock(),
                id='test-milvus-ge-query',
                store_uri='http://localhost:19530'
            )
            return processor

    @pytest.fixture
    def mock_query_request(self):
        """Create a mock query request for testing"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=10
        )
        return query

    @patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors')
    def test_processor_initialization_with_defaults(self, mock_entity_vectors):
        """Test processor initialization with default parameters"""
        taskgroup_mock = MagicMock()
        mock_vecstore = MagicMock()
        mock_entity_vectors.return_value = mock_vecstore
        processor = Processor(taskgroup=taskgroup_mock)
        # Default store URI points at a local Milvus instance.
        mock_entity_vectors.assert_called_once_with('http://localhost:19530')
        assert processor.vecstore == mock_vecstore

    @patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors')
    def test_processor_initialization_with_custom_params(self, mock_entity_vectors):
        """Test processor initialization with custom parameters"""
        taskgroup_mock = MagicMock()
        mock_vecstore = MagicMock()
        mock_entity_vectors.return_value = mock_vecstore
        processor = Processor(
            taskgroup=taskgroup_mock,
            store_uri='http://custom-milvus:19530'
        )
        mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530')
        assert processor.vecstore == mock_vecstore

    def test_create_value_with_http_uri(self, processor):
        """Test create_value with HTTP URI"""
        result = processor.create_value("http://example.com/resource")
        assert isinstance(result, Value)
        assert result.value == "http://example.com/resource"
        assert result.is_uri is True

    def test_create_value_with_https_uri(self, processor):
        """Test create_value with HTTPS URI"""
        result = processor.create_value("https://example.com/resource")
        assert isinstance(result, Value)
        assert result.value == "https://example.com/resource"
        assert result.is_uri is True

    def test_create_value_with_literal(self, processor):
        """Test create_value with literal value"""
        result = processor.create_value("just a literal string")
        assert isinstance(result, Value)
        assert result.value == "just a literal string"
        assert result.is_uri is False

    def test_create_value_with_empty_string(self, processor):
        """Test create_value with empty string"""
        result = processor.create_value("")
        assert isinstance(result, Value)
        assert result.value == ""
        assert result.is_uri is False

    def test_create_value_with_partial_uri(self, processor):
        """Test create_value with string that looks like URI but isn't complete"""
        result = processor.create_value("http")
        assert isinstance(result, Value)
        assert result.value == "http"
        assert result.is_uri is False

    def test_create_value_with_ftp_uri(self, processor):
        """Test create_value with FTP URI (should not be detected as URI)"""
        # Only http/https schemes count as URIs for this service.
        result = processor.create_value("ftp://example.com/file")
        assert isinstance(result, Value)
        assert result.value == "ftp://example.com/file"
        assert result.is_uri is False

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_single_vector(self, processor):
        """Test querying graph embeddings with a single vector"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock search results
        mock_results = [
            {"entity": {"entity": "http://example.com/entity1"}},
            {"entity": {"entity": "http://example.com/entity2"}},
            {"entity": {"entity": "literal entity"}},
        ]
        processor.vecstore.search.return_value = mock_results
        result = await processor.query_graph_embeddings(query)
        # Verify search was called with correct parameters
        # (search limit is 2*requested limit to allow for deduplication)
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)
        # Verify results are converted to Value objects
        assert len(result) == 3
        assert isinstance(result[0], Value)
        assert result[0].value == "http://example.com/entity1"
        assert result[0].is_uri is True
        assert isinstance(result[1], Value)
        assert result[1].value == "http://example.com/entity2"
        assert result[1].is_uri is True
        assert isinstance(result[2], Value)
        assert result[2].value == "literal entity"
        assert result[2].is_uri is False

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_multiple_vectors(self, processor):
        """Test querying graph embeddings with multiple vectors"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=3
        )
        # Mock search results - different results for each vector
        mock_results_1 = [
            {"entity": {"entity": "http://example.com/entity1"}},
            {"entity": {"entity": "http://example.com/entity2"}},
        ]
        mock_results_2 = [
            {"entity": {"entity": "http://example.com/entity2"}},  # Duplicate
            {"entity": {"entity": "http://example.com/entity3"}},
        ]
        processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
        result = await processor.query_graph_embeddings(query)
        # Verify search was called twice with correct parameters
        expected_calls = [
            (([0.1, 0.2, 0.3],), {"limit": 6}),
            (([0.4, 0.5, 0.6],), {"limit": 6}),
        ]
        assert processor.vecstore.search.call_count == 2
        for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
            actual_call = processor.vecstore.search.call_args_list[i]
            assert actual_call[0] == expected_args
            assert actual_call[1] == expected_kwargs
        # Verify results are deduplicated and limited
        assert len(result) == 3
        entity_values = [r.value for r in result]
        assert "http://example.com/entity1" in entity_values
        assert "http://example.com/entity2" in entity_values
        assert "http://example.com/entity3" in entity_values

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_with_limit(self, processor):
        """Test querying graph embeddings respects limit parameter"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=2
        )
        # Mock search results - more results than limit
        mock_results = [
            {"entity": {"entity": "http://example.com/entity1"}},
            {"entity": {"entity": "http://example.com/entity2"}},
            {"entity": {"entity": "http://example.com/entity3"}},
            {"entity": {"entity": "http://example.com/entity4"}},
        ]
        processor.vecstore.search.return_value = mock_results
        result = await processor.query_graph_embeddings(query)
        # Verify search was called with 2*limit for better deduplication
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)
        # Verify results are limited to the requested limit
        assert len(result) == 2

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_deduplication(self, processor):
        """Test that duplicate entities are properly deduplicated"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=5
        )
        # Mock search results with duplicates
        mock_results_1 = [
            {"entity": {"entity": "http://example.com/entity1"}},
            {"entity": {"entity": "http://example.com/entity2"}},
        ]
        mock_results_2 = [
            {"entity": {"entity": "http://example.com/entity2"}},  # Duplicate
            {"entity": {"entity": "http://example.com/entity1"}},  # Duplicate
            {"entity": {"entity": "http://example.com/entity3"}},  # New
        ]
        processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
        result = await processor.query_graph_embeddings(query)
        # Verify duplicates are removed
        assert len(result) == 3
        entity_values = [r.value for r in result]
        assert len(set(entity_values)) == 3  # All unique
        assert "http://example.com/entity1" in entity_values
        assert "http://example.com/entity2" in entity_values
        assert "http://example.com/entity3" in entity_values

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
        """Test that querying stops early when limit is reached"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=2
        )
        # Mock search results - first vector returns enough results
        mock_results_1 = [
            {"entity": {"entity": "http://example.com/entity1"}},
            {"entity": {"entity": "http://example.com/entity2"}},
            {"entity": {"entity": "http://example.com/entity3"}},
        ]
        processor.vecstore.search.return_value = mock_results_1
        result = await processor.query_graph_embeddings(query)
        # Verify only first vector was searched (limit reached)
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)
        # Verify results are limited
        assert len(result) == 2

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_empty_vectors(self, processor):
        """Test querying graph embeddings with empty vectors list"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[],
            limit=5
        )
        result = await processor.query_graph_embeddings(query)
        # Verify no search was called
        processor.vecstore.search.assert_not_called()
        # Verify empty results
        assert len(result) == 0

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_empty_search_results(self, processor):
        """Test querying graph embeddings with empty search results"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock empty search results
        processor.vecstore.search.return_value = []
        result = await processor.query_graph_embeddings(query)
        # Verify search was called
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)
        # Verify empty results
        assert len(result) == 0

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
        """Test querying graph embeddings with mixed URI and literal results"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock search results with mixed types
        mock_results = [
            {"entity": {"entity": "http://example.com/uri_entity"}},
            {"entity": {"entity": "literal entity text"}},
            {"entity": {"entity": "https://example.com/another_uri"}},
            {"entity": {"entity": "another literal"}},
        ]
        processor.vecstore.search.return_value = mock_results
        result = await processor.query_graph_embeddings(query)
        # Verify all results are properly typed
        assert len(result) == 4
        # Check URI entities
        uri_results = [r for r in result if r.is_uri]
        assert len(uri_results) == 2
        uri_values = [r.value for r in uri_results]
        assert "http://example.com/uri_entity" in uri_values
        assert "https://example.com/another_uri" in uri_values
        # Check literal entities
        literal_results = [r for r in result if not r.is_uri]
        assert len(literal_results) == 2
        literal_values = [r.value for r in literal_results]
        assert "literal entity text" in literal_values
        assert "another literal" in literal_values

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_exception_handling(self, processor):
        """Test exception handling during query processing"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )
        # Mock search to raise exception
        processor.vecstore.search.side_effect = Exception("Milvus connection failed")
        # Should raise the exception
        with pytest.raises(Exception, match="Milvus connection failed"):
            await processor.query_graph_embeddings(query)

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        # Mock the parent class add_args method
        with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)
        # Verify parent add_args was called
        mock_parent_add_args.assert_called_once()
        # Verify our specific arguments were added
        # Parse empty args to check defaults
        args = parser.parse_args([])
        assert hasattr(args, 'store_uri')
        assert args.store_uri == 'http://localhost:19530'

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)
        # Test parsing with custom values
        args = parser.parse_args([
            '--store-uri', 'http://custom-milvus:19530'
        ])
        assert args.store_uri == 'http://custom-milvus:19530'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)
        # Test parsing with short form
        args = parser.parse_args(['-t', 'http://short-milvus:19530'])
        assert args.store_uri == 'http://short-milvus:19530'

    @patch('trustgraph.query.graph_embeddings.milvus.service.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.query.graph_embeddings.milvus.service import run, default_ident
        run()
        # run() should pass the module's ident and docstring through to launch.
        mock_launch.assert_called_once_with(
            default_ident,
            "\nGraph embeddings query service. Input is vector, output is list of\nentities\n"
        )

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_zero_limit(self, processor):
        """Test querying graph embeddings with zero limit"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=0
        )
        result = await processor.query_graph_embeddings(query)
        # Verify no search was called (optimization for zero limit)
        processor.vecstore.search.assert_not_called()
        # Verify empty results due to zero limit
        assert len(result) == 0

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
        """Test querying graph embeddings with different vector dimensions"""
        query = GraphEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[
                [0.1, 0.2],  # 2D vector
                [0.3, 0.4, 0.5, 0.6],  # 4D vector
                [0.7, 0.8, 0.9]  # 3D vector
            ],
            limit=5
        )
        # Mock search results for each vector
        mock_results_1 = [{"entity": {"entity": "entity_2d"}}]
        mock_results_2 = [{"entity": {"entity": "entity_4d"}}]
        mock_results_3 = [{"entity": {"entity": "entity_3d"}}]
        processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
        result = await processor.query_graph_embeddings(query)
        # Verify all vectors were searched
        assert processor.vecstore.search.call_count == 3
        # Verify results from all dimensions
        assert len(result) == 3
        entity_values = [r.value for r in result]
        assert "entity_2d" in entity_values
        assert "entity_4d" in entity_values
        assert "entity_3d" in entity_values

View file

@ -0,0 +1,507 @@
"""
Tests for Pinecone graph embeddings query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.pinecone.service import Processor
from trustgraph.schema import Value
class TestPineconeGraphEmbeddingsQueryProcessor:
    """Test cases for Pinecone graph embeddings query processor"""

    @pytest.fixture
    def mock_query_message(self):
        """Create a mock query message for testing"""
        message = MagicMock()
        message.vectors = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6]
        ]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        return message

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        # Patch the Pinecone client class so no real API connection is made.
        with patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone') as mock_pinecone_class:
            mock_pinecone = MagicMock()
            mock_pinecone_class.return_value = mock_pinecone
            processor = Processor(
                taskgroup=MagicMock(),
                id='test-pinecone-ge-query',
                api_key='test-api-key'
            )
            return processor

    @patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone')
    @patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'env-api-key')
    def test_processor_initialization_with_defaults(self, mock_pinecone_class):
        """Test processor initialization with default parameters"""
        # default_api_key is patched with a replacement value, so only the
        # Pinecone class patch injects a mock argument here.
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(taskgroup=taskgroup_mock)
        mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
        assert processor.pinecone == mock_pinecone
        assert processor.api_key == 'env-api-key'

    @patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone')
    def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
        """Test processor initialization with custom parameters"""
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='custom-api-key'
        )
        mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
        assert processor.api_key == 'custom-api-key'

    @patch('trustgraph.query.graph_embeddings.pinecone.service.PineconeGRPC')
    def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
        """Test processor initialization with custom URL (GRPC mode)"""
        # Supplying a url switches the processor to the GRPC client class.
        mock_pinecone = MagicMock()
        mock_pinecone_grpc_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='test-api-key',
            url='https://custom-host.pinecone.io'
        )
        mock_pinecone_grpc_class.assert_called_once_with(
            api_key='test-api-key',
            host='https://custom-host.pinecone.io'
        )
        assert processor.pinecone == mock_pinecone
        assert processor.url == 'https://custom-host.pinecone.io'

    @patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'not-specified')
    def test_processor_initialization_missing_api_key(self):
        """Test processor initialization fails with missing API key"""
        taskgroup_mock = MagicMock()
        with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
            Processor(taskgroup=taskgroup_mock)

    def test_create_value_uri(self, processor):
        """Test create_value method for URI entities"""
        uri_entity = "http://example.org/entity"
        value = processor.create_value(uri_entity)
        assert isinstance(value, Value)
        assert value.value == uri_entity
        assert value.is_uri == True

    def test_create_value_https_uri(self, processor):
        """Test create_value method for HTTPS URI entities"""
        uri_entity = "https://example.org/entity"
        value = processor.create_value(uri_entity)
        assert isinstance(value, Value)
        assert value.value == uri_entity
        assert value.is_uri == True

    def test_create_value_literal(self, processor):
        """Test create_value method for literal entities"""
        literal_entity = "literal_entity"
        value = processor.create_value(literal_entity)
        assert isinstance(value, Value)
        assert value.value == literal_entity
        assert value.is_uri == False

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_single_vector(self, processor):
        """Test querying graph embeddings with a single vector"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 3
        message.user = 'test_user'
        message.collection = 'test_collection'
        # Mock index and query results
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'entity': 'http://example.org/entity1'}),
            MagicMock(metadata={'entity': 'entity2'}),
            MagicMock(metadata={'entity': 'http://example.org/entity3'})
        ]
        mock_index.query.return_value = mock_results
        entities = await processor.query_graph_embeddings(message)
        # Verify index was accessed correctly
        # (index name encodes user, collection, and vector dimensionality)
        expected_index_name = "t-test_user-test_collection-3"
        processor.pinecone.Index.assert_called_once_with(expected_index_name)
        # Verify query parameters
        mock_index.query.assert_called_once_with(
            vector=[0.1, 0.2, 0.3],
            top_k=6,  # 2 * limit
            include_values=False,
            include_metadata=True
        )
        # Verify results
        assert len(entities) == 3
        assert entities[0].value == 'http://example.org/entity1'
        assert entities[0].is_uri == True
        assert entities[1].value == 'entity2'
        assert entities[1].is_uri == False
        assert entities[2].value == 'http://example.org/entity3'
        assert entities[2].is_uri == True

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
        """Test querying graph embeddings with multiple vectors"""
        # Mock index and query results
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        # First query results
        mock_results1 = MagicMock()
        mock_results1.matches = [
            MagicMock(metadata={'entity': 'entity1'}),
            MagicMock(metadata={'entity': 'entity2'})
        ]
        # Second query results
        mock_results2 = MagicMock()
        mock_results2.matches = [
            MagicMock(metadata={'entity': 'entity2'}),  # Duplicate
            MagicMock(metadata={'entity': 'entity3'})
        ]
        mock_index.query.side_effect = [mock_results1, mock_results2]
        entities = await processor.query_graph_embeddings(mock_query_message)
        # Verify both queries were made
        assert mock_index.query.call_count == 2
        # Verify deduplication occurred
        entity_values = [e.value for e in entities]
        assert len(entity_values) == 3
        assert 'entity1' in entity_values
        assert 'entity2' in entity_values
        assert 'entity3' in entity_values

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_limit_handling(self, processor):
        """Test that query respects the limit parameter"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 2
        message.user = 'test_user'
        message.collection = 'test_collection'
        # Mock index with many results
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'entity': f'entity{i}'}) for i in range(10)
        ]
        mock_index.query.return_value = mock_results
        entities = await processor.query_graph_embeddings(message)
        # Verify limit is respected
        assert len(entities) == 2

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_zero_limit(self, processor):
        """Test querying with zero limit returns empty results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 0
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        entities = await processor.query_graph_embeddings(message)
        # Verify no query was made and empty result returned
        mock_index.query.assert_not_called()
        assert entities == []

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_negative_limit(self, processor):
        """Test querying with negative limit returns empty results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = -1
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        entities = await processor.query_graph_embeddings(message)
        # Verify no query was made and empty result returned
        mock_index.query.assert_not_called()
        assert entities == []

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
        """Test querying with vectors of different dimensions"""
        message = MagicMock()
        message.vectors = [
            [0.1, 0.2],  # 2D vector
            [0.3, 0.4, 0.5, 0.6]  # 4D vector
        ]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index_2d = MagicMock()
        mock_index_4d = MagicMock()

        # Route each dimensionality to its own index mock; the index name
        # suffix is the vector dimension.
        def mock_index_side_effect(name):
            if name.endswith("-2"):
                return mock_index_2d
            elif name.endswith("-4"):
                return mock_index_4d

        processor.pinecone.Index.side_effect = mock_index_side_effect
        # Mock results for different dimensions
        mock_results_2d = MagicMock()
        mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
        mock_index_2d.query.return_value = mock_results_2d
        mock_results_4d = MagicMock()
        mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
        mock_index_4d.query.return_value = mock_results_4d
        entities = await processor.query_graph_embeddings(message)
        # Verify different indexes were used
        assert processor.pinecone.Index.call_count == 2
        mock_index_2d.query.assert_called_once()
        mock_index_4d.query.assert_called_once()
        # Verify results from both dimensions
        entity_values = [e.value for e in entities]
        assert 'entity_2d' in entity_values
        assert 'entity_4d' in entity_values

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_empty_vectors_list(self, processor):
        """Test querying with empty vectors list"""
        message = MagicMock()
        message.vectors = []
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        entities = await processor.query_graph_embeddings(message)
        # Verify no queries were made and empty result returned
        processor.pinecone.Index.assert_not_called()
        mock_index.query.assert_not_called()
        assert entities == []

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_no_results(self, processor):
        """Test querying when index returns no results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_results = MagicMock()
        mock_results.matches = []
        mock_index.query.return_value = mock_results
        entities = await processor.query_graph_embeddings(message)
        # Verify empty results
        assert entities == []

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
        """Test that deduplication works correctly across multiple vector queries"""
        message = MagicMock()
        message.vectors = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6]
        ]
        message.limit = 3
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        # Both queries return overlapping results
        mock_results1 = MagicMock()
        mock_results1.matches = [
            MagicMock(metadata={'entity': 'entity1'}),
            MagicMock(metadata={'entity': 'entity2'}),
            MagicMock(metadata={'entity': 'entity3'}),
            MagicMock(metadata={'entity': 'entity4'})
        ]
        mock_results2 = MagicMock()
        mock_results2.matches = [
            MagicMock(metadata={'entity': 'entity2'}),  # Duplicate
            MagicMock(metadata={'entity': 'entity3'}),  # Duplicate
            MagicMock(metadata={'entity': 'entity5'})
        ]
        mock_index.query.side_effect = [mock_results1, mock_results2]
        entities = await processor.query_graph_embeddings(message)
        # Should get exactly 3 unique entities (respecting limit)
        assert len(entities) == 3
        entity_values = [e.value for e in entities]
        assert len(set(entity_values)) == 3  # All unique

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
        """Test that querying stops early when limit is reached"""
        message = MagicMock()
        message.vectors = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9]
        ]
        message.limit = 2
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        # First query returns enough results to meet limit
        mock_results1 = MagicMock()
        mock_results1.matches = [
            MagicMock(metadata={'entity': 'entity1'}),
            MagicMock(metadata={'entity': 'entity2'}),
            MagicMock(metadata={'entity': 'entity3'})
        ]
        mock_index.query.return_value = mock_results1
        entities = await processor.query_graph_embeddings(message)
        # Should only make one query since limit was reached
        mock_index.query.assert_called_once()
        assert len(entities) == 2

    @pytest.mark.asyncio
    async def test_query_graph_embeddings_exception_handling(self, processor):
        """Test that exceptions are properly raised"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_index.query.side_effect = Exception("Query failed")
        with pytest.raises(Exception, match="Query failed"):
            await processor.query_graph_embeddings(message)

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        # Mock the parent class add_args method
        with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)
            # Verify parent add_args was called
            mock_parent_add_args.assert_called_once()
        # Verify our specific arguments were added
        args = parser.parse_args([])
        assert hasattr(args, 'api_key')
        assert args.api_key == 'not-specified'  # Default value when no env var
        assert hasattr(args, 'url')
        assert args.url is None

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)
        # Test parsing with custom values
        args = parser.parse_args([
            '--api-key', 'custom-api-key',
            '--url', 'https://custom-host.pinecone.io'
        ])
        assert args.api_key == 'custom-api-key'
        assert args.url == 'https://custom-host.pinecone.io'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)
        # Test parsing with short form
        args = parser.parse_args([
            '-a', 'short-api-key',
            '-u', 'https://short-host.pinecone.io'
        ])
        assert args.api_key == 'short-api-key'
        assert args.url == 'https://short-host.pinecone.io'

    @patch('trustgraph.query.graph_embeddings.pinecone.service.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.query.graph_embeddings.pinecone.service import run, default_ident
        run()
        mock_launch.assert_called_once_with(
            default_ident,
            "\nGraph embeddings query service. Input is vector, output is list of\nentities. Pinecone implementation.\n"
        )

View file

@ -0,0 +1,556 @@
"""
Tests for FalkorDB triples query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.falkordb.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
class TestFalkorDBQueryProcessor:
    """Test cases for FalkorDB query processor"""

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        # Patch the FalkorDB client class so no real connection is made.
        with patch('trustgraph.query.triples.falkordb.service.FalkorDB'):
            return Processor(
                taskgroup=MagicMock(),
                id='test-falkordb-query',
                graph_url='falkor://localhost:6379'
            )

    def test_create_value_with_http_uri(self, processor):
        """Test create_value with HTTP URI"""
        result = processor.create_value("http://example.com/resource")
        assert isinstance(result, Value)
        assert result.value == "http://example.com/resource"
        assert result.is_uri is True

    def test_create_value_with_https_uri(self, processor):
        """Test create_value with HTTPS URI"""
        result = processor.create_value("https://example.com/resource")
        assert isinstance(result, Value)
        assert result.value == "https://example.com/resource"
        assert result.is_uri is True

    def test_create_value_with_literal(self, processor):
        """Test create_value with literal value"""
        result = processor.create_value("just a literal string")
        assert isinstance(result, Value)
        assert result.value == "just a literal string"
        assert result.is_uri is False

    def test_create_value_with_empty_string(self, processor):
        """Test create_value with empty string"""
        result = processor.create_value("")
        assert isinstance(result, Value)
        assert result.value == ""
        assert result.is_uri is False

    def test_create_value_with_partial_uri(self, processor):
        """Test create_value with string that looks like URI but isn't complete"""
        result = processor.create_value("http")
        assert isinstance(result, Value)
        assert result.value == "http"
        assert result.is_uri is False

    def test_create_value_with_ftp_uri(self, processor):
        """Test create_value with FTP URI (should not be detected as URI)"""
        result = processor.create_value("ftp://example.com/file")
        assert isinstance(result, Value)
        assert result.value == "ftp://example.com/file"
        assert result.is_uri is False

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    def test_processor_initialization_with_defaults(self, mock_falkordb):
        """Test processor initialization with default parameters"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        processor = Processor(taskgroup=taskgroup_mock)
        assert processor.db == 'falkordb'
        mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379')
        mock_client.select_graph.assert_called_once_with('falkordb')

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    def test_processor_initialization_with_custom_params(self, mock_falkordb):
        """Test processor initialization with custom parameters"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        processor = Processor(
            taskgroup=taskgroup_mock,
            graph_url='falkor://custom:6379',
            database='customdb'
        )
        assert processor.db == 'customdb'
        mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379')
        mock_client.select_graph.assert_called_once_with('customdb')

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    @pytest.mark.asyncio
    async def test_query_triples_spo_query(self, mock_falkordb):
        """Test SPO query (all values specified)"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        # Mock query results - both queries return one record each
        mock_result = MagicMock()
        mock_result.result_set = [["record1"]]
        mock_graph.query.return_value = mock_result
        processor = Processor(taskgroup=taskgroup_mock)
        # Create query request
        query = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=Value(value="http://example.com/subject", is_uri=True),
            p=Value(value="http://example.com/predicate", is_uri=True),
            o=Value(value="literal object", is_uri=False),
            limit=100
        )
        result = await processor.query_triples(query)
        # Verify both literal and URI queries were executed
        assert mock_graph.query.call_count == 2
        # Verify result contains the queried triple (appears twice - once from each query)
        assert len(result) == 2
        assert result[0].s.value == "http://example.com/subject"
        assert result[0].p.value == "http://example.com/predicate"
        assert result[0].o.value == "literal object"

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    @pytest.mark.asyncio
    async def test_query_triples_sp_query(self, mock_falkordb):
        """Test SP query (subject and predicate specified)"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        # Mock query results with different objects
        mock_result1 = MagicMock()
        mock_result1.result_set = [["literal result"]]
        mock_result2 = MagicMock()
        mock_result2.result_set = [["http://example.com/uri_result"]]
        mock_graph.query.side_effect = [mock_result1, mock_result2]
        processor = Processor(taskgroup=taskgroup_mock)
        # Create query request
        query = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=Value(value="http://example.com/subject", is_uri=True),
            p=Value(value="http://example.com/predicate", is_uri=True),
            o=None,
            limit=100
        )
        result = await processor.query_triples(query)
        # Verify both literal and URI queries were executed
        assert mock_graph.query.call_count == 2
        # Verify results contain different objects
        assert len(result) == 2
        assert result[0].s.value == "http://example.com/subject"
        assert result[0].p.value == "http://example.com/predicate"
        assert result[0].o.value == "literal result"
        assert result[1].s.value == "http://example.com/subject"
        assert result[1].p.value == "http://example.com/predicate"
        assert result[1].o.value == "http://example.com/uri_result"

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    @pytest.mark.asyncio
    async def test_query_triples_so_query(self, mock_falkordb):
        """Test SO query (subject and object specified)"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        # Mock query results with different predicates
        mock_result1 = MagicMock()
        mock_result1.result_set = [["http://example.com/pred1"]]
        mock_result2 = MagicMock()
        mock_result2.result_set = [["http://example.com/pred2"]]
        mock_graph.query.side_effect = [mock_result1, mock_result2]
        processor = Processor(taskgroup=taskgroup_mock)
        # Create query request
        query = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=Value(value="http://example.com/subject", is_uri=True),
            p=None,
            o=Value(value="literal object", is_uri=False),
            limit=100
        )
        result = await processor.query_triples(query)
        # Verify both literal and URI queries were executed
        assert mock_graph.query.call_count == 2
        # Verify results contain different predicates
        assert len(result) == 2
        assert result[0].s.value == "http://example.com/subject"
        assert result[0].p.value == "http://example.com/pred1"
        assert result[0].o.value == "literal object"
        assert result[1].s.value == "http://example.com/subject"
        assert result[1].p.value == "http://example.com/pred2"
        assert result[1].o.value == "literal object"

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    @pytest.mark.asyncio
    async def test_query_triples_s_query(self, mock_falkordb):
        """Test S query (subject only)"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        # Mock query results with different predicate-object pairs
        mock_result1 = MagicMock()
        mock_result1.result_set = [["http://example.com/pred1", "literal1"]]
        mock_result2 = MagicMock()
        mock_result2.result_set = [["http://example.com/pred2", "http://example.com/uri2"]]
        mock_graph.query.side_effect = [mock_result1, mock_result2]
        processor = Processor(taskgroup=taskgroup_mock)
        # Create query request
        query = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=Value(value="http://example.com/subject", is_uri=True),
            p=None,
            o=None,
            limit=100
        )
        result = await processor.query_triples(query)
        # Verify both literal and URI queries were executed
        assert mock_graph.query.call_count == 2
        # Verify results contain different predicate-object pairs
        assert len(result) == 2
        assert result[0].s.value == "http://example.com/subject"
        assert result[0].p.value == "http://example.com/pred1"
        assert result[0].o.value == "literal1"
        assert result[1].s.value == "http://example.com/subject"
        assert result[1].p.value == "http://example.com/pred2"
        assert result[1].o.value == "http://example.com/uri2"

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    @pytest.mark.asyncio
    async def test_query_triples_po_query(self, mock_falkordb):
        """Test PO query (predicate and object specified)"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        # Mock query results with different subjects
        mock_result1 = MagicMock()
        mock_result1.result_set = [["http://example.com/subj1"]]
        mock_result2 = MagicMock()
        mock_result2.result_set = [["http://example.com/subj2"]]
        mock_graph.query.side_effect = [mock_result1, mock_result2]
        processor = Processor(taskgroup=taskgroup_mock)
        # Create query request
        query = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=None,
            p=Value(value="http://example.com/predicate", is_uri=True),
            o=Value(value="literal object", is_uri=False),
            limit=100
        )
        result = await processor.query_triples(query)
        # Verify both literal and URI queries were executed
        assert mock_graph.query.call_count == 2
        # Verify results contain different subjects
        assert len(result) == 2
        assert result[0].s.value == "http://example.com/subj1"
        assert result[0].p.value == "http://example.com/predicate"
        assert result[0].o.value == "literal object"
        assert result[1].s.value == "http://example.com/subj2"
        assert result[1].p.value == "http://example.com/predicate"
        assert result[1].o.value == "literal object"

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    @pytest.mark.asyncio
    async def test_query_triples_p_query(self, mock_falkordb):
        """Test P query (predicate only)"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        # Mock query results with different subject-object pairs
        mock_result1 = MagicMock()
        mock_result1.result_set = [["http://example.com/subj1", "literal1"]]
        mock_result2 = MagicMock()
        mock_result2.result_set = [["http://example.com/subj2", "http://example.com/uri2"]]
        mock_graph.query.side_effect = [mock_result1, mock_result2]
        processor = Processor(taskgroup=taskgroup_mock)
        # Create query request
        query = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=None,
            p=Value(value="http://example.com/predicate", is_uri=True),
            o=None,
            limit=100
        )
        result = await processor.query_triples(query)
        # Verify both literal and URI queries were executed
        assert mock_graph.query.call_count == 2
        # Verify results contain different subject-object pairs
        assert len(result) == 2
        assert result[0].s.value == "http://example.com/subj1"
        assert result[0].p.value == "http://example.com/predicate"
        assert result[0].o.value == "literal1"
        assert result[1].s.value == "http://example.com/subj2"
        assert result[1].p.value == "http://example.com/predicate"
        assert result[1].o.value == "http://example.com/uri2"

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    @pytest.mark.asyncio
    async def test_query_triples_o_query(self, mock_falkordb):
        """Test O query (object only)"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        # Mock query results with different subject-predicate pairs
        mock_result1 = MagicMock()
        mock_result1.result_set = [["http://example.com/subj1", "http://example.com/pred1"]]
        mock_result2 = MagicMock()
        mock_result2.result_set = [["http://example.com/subj2", "http://example.com/pred2"]]
        mock_graph.query.side_effect = [mock_result1, mock_result2]
        processor = Processor(taskgroup=taskgroup_mock)
        # Create query request
        query = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=None,
            p=None,
            o=Value(value="literal object", is_uri=False),
            limit=100
        )
        result = await processor.query_triples(query)
        # Verify both literal and URI queries were executed
        assert mock_graph.query.call_count == 2
        # Verify results contain different subject-predicate pairs
        assert len(result) == 2
        assert result[0].s.value == "http://example.com/subj1"
        assert result[0].p.value == "http://example.com/pred1"
        assert result[0].o.value == "literal object"
        assert result[1].s.value == "http://example.com/subj2"
        assert result[1].p.value == "http://example.com/pred2"
        assert result[1].o.value == "literal object"

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    @pytest.mark.asyncio
    async def test_query_triples_wildcard_query(self, mock_falkordb):
        """Test wildcard query (no constraints)"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        # Mock query results
        mock_result1 = MagicMock()
        mock_result1.result_set = [["http://example.com/s1", "http://example.com/p1", "literal1"]]
        mock_result2 = MagicMock()
        mock_result2.result_set = [["http://example.com/s2", "http://example.com/p2", "http://example.com/o2"]]
        mock_graph.query.side_effect = [mock_result1, mock_result2]
        processor = Processor(taskgroup=taskgroup_mock)
        # Create query request
        query = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=None,
            p=None,
            o=None,
            limit=100
        )
        result = await processor.query_triples(query)
        # Verify both literal and URI queries were executed
        assert mock_graph.query.call_count == 2
        # Verify results contain different triples
        assert len(result) == 2
        assert result[0].s.value == "http://example.com/s1"
        assert result[0].p.value == "http://example.com/p1"
        assert result[0].o.value == "literal1"
        assert result[1].s.value == "http://example.com/s2"
        assert result[1].p.value == "http://example.com/p2"
        assert result[1].o.value == "http://example.com/o2"

    @patch('trustgraph.query.triples.falkordb.service.FalkorDB')
    @pytest.mark.asyncio
    async def test_query_triples_exception_handling(self, mock_falkordb):
        """Test exception handling during query processing"""
        taskgroup_mock = MagicMock()
        mock_client = MagicMock()
        mock_graph = MagicMock()
        mock_falkordb.from_url.return_value = mock_client
        mock_client.select_graph.return_value = mock_graph
        # Mock query to raise exception
        mock_graph.query.side_effect = Exception("Database connection failed")
        processor = Processor(taskgroup=taskgroup_mock)
        # Create query request
        query = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=Value(value="http://example.com/subject", is_uri=True),
            p=None,
            o=None,
            limit=100
        )
        # Should raise the exception
        with pytest.raises(Exception, match="Database connection failed"):
            await processor.query_triples(query)

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        # Mock the parent class add_args method
        with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)
            # Verify parent add_args was called
            mock_parent_add_args.assert_called_once()
        # Verify our specific arguments were added
        # Parse empty args to check defaults
        args = parser.parse_args([])
        assert hasattr(args, 'graph_url')
        assert args.graph_url == 'falkor://falkordb:6379'
        assert hasattr(args, 'database')
        assert args.database == 'falkordb'

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'):
            Processor.add_args(parser)
        # Test parsing with custom values
        args = parser.parse_args([
            '--graph-url', 'falkor://custom:6379',
            '--database', 'querydb'
        ])
        assert args.graph_url == 'falkor://custom:6379'
        assert args.database == 'querydb'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'):
            Processor.add_args(parser)
        # Test parsing with short form
        args = parser.parse_args(['-g', 'falkor://short:6379'])
        assert args.graph_url == 'falkor://short:6379'

    @patch('trustgraph.query.triples.falkordb.service.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.query.triples.falkordb.service import run, default_ident
        run()
        mock_launch.assert_called_once_with(
            default_ident,
            "\nTriples query service for FalkorDB.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
        )

View file

@ -0,0 +1,568 @@
"""
Tests for Memgraph triples query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
class TestMemgraphQueryProcessor:
"""Test cases for Memgraph query processor"""
# All tests stub out the bolt driver (GraphDatabase), so no live Memgraph
# instance is required.
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.query.triples.memgraph.service.GraphDatabase'):
return Processor(
taskgroup=MagicMock(),
id='test-memgraph-query',
graph_host='bolt://localhost:7687'
)
# create_value tests: only http:// and https:// prefixes are treated as URIs;
# everything else (including ftp://) is a literal.
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert result.value == "just a literal string"
assert result.is_uri is False
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert result.value == ""
assert result.is_uri is False
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert result.value == "http"
assert result.is_uri is False
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'memgraph'
mock_graph_db.driver.assert_called_once_with(
'bolt://memgraph:7687',
auth=('memgraph', 'password')
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='bolt://custom:7687',
username='queryuser',
password='querypass',
database='customdb'
)
assert processor.db == 'customdb'
mock_graph_db.driver.assert_called_once_with(
'bolt://custom:7687',
auth=('queryuser', 'querypass')
)
# Query-shape tests: for every S/P/O constraint combination the driver's
# execute_query is called twice per request, and the results of both calls
# are concatenated. NOTE(review): presumably one pass matches literal
# objects and one matches URI objects -- confirm against the service
# implementation.
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_spo_query(self, mock_graph_db):
"""Test SPO query (all values specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results - both queries return one record each
mock_records = [MagicMock()]
mock_driver.execute_query.return_value = (mock_records, None, None)
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_sp_query(self, mock_graph_db):
"""Test SP query (subject and predicate specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different objects
mock_record1 = MagicMock()
mock_record1.data.return_value = {"dest": "literal result"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"dest": "http://example.com/uri_result"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None),  # Literal query
([mock_record2], None, None)   # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_so_query(self, mock_graph_db):
"""Test SO query (subject and object specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different predicates
mock_record1 = MagicMock()
mock_record1.data.return_value = {"rel": "http://example.com/pred1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"rel": "http://example.com/pred2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None),  # Literal query
([mock_record2], None, None)   # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different predicates
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_s_query(self, mock_graph_db):
"""Test S query (subject only)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different predicate-object pairs
mock_record1 = MagicMock()
mock_record1.data.return_value = {"rel": "http://example.com/pred1", "dest": "literal1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"rel": "http://example.com/pred2", "dest": "http://example.com/uri2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None),  # Literal query
([mock_record2], None, None)   # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different predicate-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "http://example.com/uri2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_po_query(self, mock_graph_db):
"""Test PO query (predicate and object specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different subjects
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/subj1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/subj2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None),  # Literal query
([mock_record2], None, None)   # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different subjects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_p_query(self, mock_graph_db):
"""Test P query (predicate only)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different subject-object pairs
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/subj1", "dest": "literal1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/subj2", "dest": "http://example.com/uri2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None),  # Literal query
([mock_record2], None, None)   # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different subject-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_o_query(self, mock_graph_db):
"""Test O query (object only)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different subject-predicate pairs
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/subj1", "rel": "http://example.com/pred1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/subj2", "rel": "http://example.com/pred2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None),  # Literal query
([mock_record2], None, None)   # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different subject-predicate pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_wildcard_query(self, mock_graph_db):
"""Test wildcard query (no constraints)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/s1", "rel": "http://example.com/p1", "dest": "literal1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/s2", "rel": "http://example.com/p2", "dest": "http://example.com/o2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None),  # Literal query
([mock_record2], None, None)   # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_exception_handling(self, mock_graph_db):
"""Test exception handling during query processing"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock execute_query to raise exception
mock_driver.execute_query.side_effect = Exception("Database connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=None,
limit=100
)
# Should raise the exception
# Driver errors are expected to propagate unchanged to the caller.
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)
# CLI argument tests: add_args must delegate to the parent service class and
# register the memgraph-specific options with their documented defaults.
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://memgraph:7687'
assert hasattr(args, 'username')
assert args.username == 'memgraph'
assert hasattr(args, 'password')
assert args.password == 'password'
assert hasattr(args, 'database')
assert args.database == 'memgraph'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-host', 'bolt://custom:7687',
'--username', 'queryuser',
'--password', 'querypass',
'--database', 'querydb'
])
assert args.graph_host == 'bolt://custom:7687'
assert args.username == 'queryuser'
assert args.password == 'querypass'
assert args.database == 'querydb'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'bolt://short:7687'])
assert args.graph_host == 'bolt://short:7687'
@patch('trustgraph.query.triples.memgraph.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.triples.memgraph.service import run, default_ident
run()
# The doc banner must be passed to launch verbatim.
mock_launch.assert_called_once_with(
default_ident,
"\nTriples query service for memgraph.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
)

View file

@ -0,0 +1,338 @@
"""
Tests for Neo4j triples query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
class TestNeo4jQueryProcessor:
"""Test cases for Neo4j query processor"""
# All tests stub out the bolt driver (GraphDatabase), so no live Neo4j
# instance is required.
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.query.triples.neo4j.service.GraphDatabase'):
return Processor(
taskgroup=MagicMock(),
id='test-neo4j-query',
graph_host='bolt://localhost:7687'
)
# create_value tests: only http:// and https:// prefixes are treated as URIs.
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert result.value == "just a literal string"
assert result.is_uri is False
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert result.value == ""
assert result.is_uri is False
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert result.value == "http"
assert result.is_uri is False
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'neo4j'
mock_graph_db.driver.assert_called_once_with(
'bolt://neo4j:7687',
auth=('neo4j', 'password')
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='bolt://custom:7687',
username='queryuser',
password='querypass',
database='customdb'
)
assert processor.db == 'customdb'
mock_graph_db.driver.assert_called_once_with(
'bolt://custom:7687',
auth=('queryuser', 'querypass')
)
# Query tests: execute_query is expected to run twice per request and the
# results of both calls are concatenated. NOTE(review): presumably one pass
# for literal objects and one for URI objects -- confirm against the service
# implementation.
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_spo_query(self, mock_graph_db):
"""Test SPO query (all values specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results - both queries return one record each
mock_records = [MagicMock()]
mock_driver.execute_query.return_value = (mock_records, None, None)
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_sp_query(self, mock_graph_db):
"""Test SP query (subject and predicate specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different objects
mock_record1 = MagicMock()
mock_record1.data.return_value = {"dest": "literal result"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"dest": "http://example.com/uri_result"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None),  # Literal query
([mock_record2], None, None)   # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_wildcard_query(self, mock_graph_db):
"""Test wildcard query (no constraints)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/s1", "rel": "http://example.com/p1", "dest": "literal1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/s2", "rel": "http://example.com/p2", "dest": "http://example.com/o2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None),  # Literal query
([mock_record2], None, None)   # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_exception_handling(self, mock_graph_db):
"""Test exception handling during query processing"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock execute_query to raise exception
mock_driver.execute_query.side_effect = Exception("Database connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=None,
limit=100
)
# Should raise the exception
# Driver errors are expected to propagate unchanged to the caller.
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)
# CLI argument tests: add_args must delegate to the parent service class and
# register the neo4j-specific options with their documented defaults.
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://neo4j:7687'
assert hasattr(args, 'username')
assert args.username == 'neo4j'
assert hasattr(args, 'password')
assert args.password == 'password'
assert hasattr(args, 'database')
assert args.database == 'neo4j'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-host', 'bolt://custom:7687',
'--username', 'queryuser',
'--password', 'querypass',
'--database', 'querydb'
])
assert args.graph_host == 'bolt://custom:7687'
assert args.username == 'queryuser'
assert args.password == 'querypass'
assert args.database == 'querydb'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'bolt://short:7687'])
assert args.graph_host == 'bolt://short:7687'
@patch('trustgraph.query.triples.neo4j.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.triples.neo4j.service import run, default_ident
run()
# The doc banner must be passed to launch verbatim.
mock_launch.assert_called_once_with(
default_ident,
"\nTriples query service for neo4j.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
)

View file

@ -0,0 +1,387 @@
"""
Tests for Milvus document embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.doc_embeddings.milvus.write import Processor
from trustgraph.schema import ChunkEmbeddings
class TestMilvusDocEmbeddingsStorageProcessor:
"""Test cases for Milvus document embeddings storage processor"""
@pytest.fixture
def mock_message(self):
    """Build a fake embeddings message carrying two document chunks."""
    msg = MagicMock()
    msg.metadata = MagicMock()
    msg.metadata.user = 'test_user'
    msg.metadata.collection = 'test_collection'
    # Two chunks: the first carries two vectors, the second a single one.
    msg.chunks = [
        ChunkEmbeddings(
            chunk=b"This is the first document chunk",
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        ),
        ChunkEmbeddings(
            chunk=b"This is the second document chunk",
            vectors=[[0.7, 0.8, 0.9]],
        ),
    ]
    return msg
@pytest.fixture
def processor(self):
    """Processor wired to a mocked Milvus DocVectors store."""
    with patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors') as doc_vectors_cls:
        doc_vectors_cls.return_value = MagicMock()
        return Processor(
            taskgroup=MagicMock(),
            id='test-milvus-de-storage',
            store_uri='http://localhost:19530'
        )
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
def test_processor_initialization_with_defaults(self, mock_doc_vectors):
    """Without arguments the processor targets the default Milvus URI."""
    store = MagicMock()
    mock_doc_vectors.return_value = store
    processor = Processor(taskgroup=MagicMock())
    # The vector store must be constructed exactly once with the default endpoint.
    mock_doc_vectors.assert_called_once_with('http://localhost:19530')
    assert processor.vecstore == store
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
def test_processor_initialization_with_custom_params(self, mock_doc_vectors):
    """A custom store_uri is forwarded to the DocVectors constructor."""
    store = MagicMock()
    mock_doc_vectors.return_value = store
    processor = Processor(
        taskgroup=MagicMock(),
        store_uri='http://custom-milvus:19530'
    )
    mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530')
    assert processor.vecstore == store
@pytest.mark.asyncio
async def test_store_document_embeddings_single_chunk(self, processor):
    """Every vector of a single chunk is inserted with the chunk's text."""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'
    message.chunks = [
        ChunkEmbeddings(
            chunk=b"Test document content",
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        )
    ]

    await processor.store_document_embeddings(message)

    # One insert per vector, each paired with the decoded chunk text.
    assert processor.vecstore.insert.call_count == 2
    expected = [
        ([0.1, 0.2, 0.3], "Test document content"),
        ([0.4, 0.5, 0.6], "Test document content"),
    ]
    for call, (vec, doc) in zip(processor.vecstore.insert.call_args_list, expected):
        assert call[0][0] == vec
        assert call[0][1] == doc
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
    """Vectors from every chunk are inserted, in message order."""
    await processor.store_document_embeddings(mock_message)

    expected = [
        # First chunk contributes two inserts, second chunk one.
        ([0.1, 0.2, 0.3], "This is the first document chunk"),
        ([0.4, 0.5, 0.6], "This is the first document chunk"),
        ([0.7, 0.8, 0.9], "This is the second document chunk"),
    ]
    assert processor.vecstore.insert.call_count == 3
    for call, (vec, doc) in zip(processor.vecstore.insert.call_args_list, expected):
        assert call[0][0] == vec
        assert call[0][1] == doc
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunk(self, processor):
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called for empty chunk
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_none_chunk(self, processor):
"""Test storing document embeddings with None chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=None,
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called for None chunk
processor.vecstore.insert.assert_not_called()
    @pytest.mark.asyncio
    async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor):
        """Test storing document embeddings with mix of valid and invalid chunks"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        # A normal chunk that should be stored.
        valid_chunk = ChunkEmbeddings(
            chunk=b"Valid document content",
            vectors=[[0.1, 0.2, 0.3]]
        )
        # Empty payload: expected to be filtered out by the processor.
        empty_chunk = ChunkEmbeddings(
            chunk=b"",
            vectors=[[0.4, 0.5, 0.6]]
        )
        # Missing payload: also expected to be filtered out.
        none_chunk = ChunkEmbeddings(
            chunk=None,
            vectors=[[0.7, 0.8, 0.9]]
        )
        message.chunks = [valid_chunk, empty_chunk, none_chunk]
        await processor.store_document_embeddings(message)
        # Verify only valid chunk was inserted
        processor.vecstore.insert.assert_called_once_with(
            [0.1, 0.2, 0.3], "Valid document content"
        )
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor):
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
await processor.store_document_embeddings(message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with no vectors",
vectors=[]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
    @pytest.mark.asyncio
    async def test_store_document_embeddings_different_vector_dimensions(self, processor):
        """Test storing document embeddings with different vector dimensions"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        # The expectation encoded here: the writer passes vectors through
        # without any dimension validation.
        chunk = ChunkEmbeddings(
            chunk=b"Document with mixed dimensions",
            vectors=[
                [0.1, 0.2],  # 2D vector
                [0.3, 0.4, 0.5, 0.6],  # 4D vector
                [0.7, 0.8, 0.9]  # 3D vector
            ]
        )
        message.chunks = [chunk]
        await processor.store_document_embeddings(message)
        # Verify all vectors were inserted regardless of dimension
        expected_calls = [
            ([0.1, 0.2], "Document with mixed dimensions"),
            ([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions"),
            ([0.7, 0.8, 0.9], "Document with mixed dimensions"),
        ]
        assert processor.vecstore.insert.call_count == 3
        for i, (expected_vec, expected_doc) in enumerate(expected_calls):
            actual_call = processor.vecstore.insert.call_args_list[i]
            assert actual_call[0][0] == expected_vec
            assert actual_call[0][1] == expected_doc
@pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor):
"""Test storing document embeddings with Unicode content"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify Unicode content was properly decoded and inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀"
)
@pytest.mark.asyncio
async def test_store_document_embeddings_large_chunks(self, processor):
"""Test storing document embeddings with large document chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a large document chunk
large_content = "A" * 10000 # 10KB of content
chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify large content was inserted
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], large_content
)
@pytest.mark.asyncio
async def test_store_document_embeddings_whitespace_only_chunk(self, processor):
"""Test storing document embeddings with whitespace-only chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b" \n\t ",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify whitespace content was inserted (not filtered out)
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], " \n\t "
)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'store_uri')
assert args.store_uri == 'http://localhost:19530'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--store-uri', 'http://custom-milvus:19530'
])
assert args.store_uri == 'http://custom-milvus:19530'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
assert args.store_uri == 'http://short-milvus:19530'
@patch('trustgraph.storage.doc_embeddings.milvus.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.doc_embeddings.milvus.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
)

View file

@ -0,0 +1,536 @@
"""
Tests for Pinecone document embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
import uuid
from trustgraph.storage.doc_embeddings.pinecone.write import Processor
from trustgraph.schema import ChunkEmbeddings
class TestPineconeDocEmbeddingsStorageProcessor:
    """Test cases for Pinecone document embeddings storage processor"""

    @pytest.fixture
    def mock_message(self):
        """Create a mock message for testing"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        # Create test document embeddings: two chunks, three vectors in total.
        chunk1 = ChunkEmbeddings(
            chunk=b"This is the first document chunk",
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        )
        chunk2 = ChunkEmbeddings(
            chunk=b"This is the second document chunk",
            vectors=[[0.7, 0.8, 0.9]]
        )
        message.chunks = [chunk1, chunk2]
        return message

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        # Patch the Pinecone client class so no network connection is made.
        with patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone') as mock_pinecone_class:
            mock_pinecone = MagicMock()
            mock_pinecone_class.return_value = mock_pinecone
            processor = Processor(
                taskgroup=MagicMock(),
                id='test-pinecone-de-storage',
                api_key='test-api-key'
            )
            return processor

    @patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
    @patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'env-api-key')
    def test_processor_initialization_with_defaults(self, mock_pinecone_class):
        """Test processor initialization with default parameters"""
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(taskgroup=taskgroup_mock)
        # With no explicit key, the module-level default (env-derived) is used.
        mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
        assert processor.pinecone == mock_pinecone
        assert processor.api_key == 'env-api-key'
        # Serverless placement defaults.
        assert processor.cloud == 'aws'
        assert processor.region == 'us-east-1'

    @patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
    def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
        """Test processor initialization with custom parameters"""
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='custom-api-key',
            cloud='gcp',
            region='us-west1'
        )
        mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
        assert processor.api_key == 'custom-api-key'
        assert processor.cloud == 'gcp'
        assert processor.region == 'us-west1'

    @patch('trustgraph.storage.doc_embeddings.pinecone.write.PineconeGRPC')
    def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
        """Test processor initialization with custom URL (GRPC mode)"""
        mock_pinecone = MagicMock()
        mock_pinecone_grpc_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='test-api-key',
            url='https://custom-host.pinecone.io'
        )
        # Supplying a URL switches the processor to the GRPC client.
        mock_pinecone_grpc_class.assert_called_once_with(
            api_key='test-api-key',
            host='https://custom-host.pinecone.io'
        )
        assert processor.pinecone == mock_pinecone
        assert processor.url == 'https://custom-host.pinecone.io'

    @patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'not-specified')
    def test_processor_initialization_missing_api_key(self):
        """Test processor initialization fails with missing API key"""
        taskgroup_mock = MagicMock()
        # 'not-specified' is the sentinel for an unset PINECONE_API_KEY.
        with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
            Processor(taskgroup=taskgroup_mock)

    @pytest.mark.asyncio
    async def test_store_document_embeddings_single_chunk(self, processor):
        """Test storing document embeddings for a single chunk"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk=b"Test document content",
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        )
        message.chunks = [chunk]
        # Mock index operations
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        processor.pinecone.has_index.return_value = True
        # Pin the generated vector ids so they can be asserted.
        with patch('uuid.uuid4', side_effect=['id1', 'id2']):
            await processor.store_document_embeddings(message)
        # Verify index name and operations.
        # Index name encodes "d-" prefix, user, collection and vector dimension.
        expected_index_name = "d-test_user-test_collection-3"
        processor.pinecone.Index.assert_called_with(expected_index_name)
        # Verify upsert was called for each vector
        assert mock_index.upsert.call_count == 2
        # Check first vector upsert
        first_call = mock_index.upsert.call_args_list[0]
        first_vectors = first_call[1]['vectors']
        assert len(first_vectors) == 1
        assert first_vectors[0]['id'] == 'id1'
        assert first_vectors[0]['values'] == [0.1, 0.2, 0.3]
        assert first_vectors[0]['metadata']['doc'] == "Test document content"
        # Check second vector upsert
        second_call = mock_index.upsert.call_args_list[1]
        second_vectors = second_call[1]['vectors']
        assert len(second_vectors) == 1
        assert second_vectors[0]['id'] == 'id2'
        assert second_vectors[0]['values'] == [0.4, 0.5, 0.6]
        assert second_vectors[0]['metadata']['doc'] == "Test document content"

    @pytest.mark.asyncio
    async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
        """Test storing document embeddings for multiple chunks"""
        # Mock index operations
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        processor.pinecone.has_index.return_value = True
        with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
            await processor.store_document_embeddings(mock_message)
        # Verify upsert was called for each vector (3 total)
        assert mock_index.upsert.call_count == 3
        # Verify document content in metadata
        calls = mock_index.upsert.call_args_list
        assert calls[0][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
        assert calls[1][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
        assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk"

    @pytest.mark.asyncio
    async def test_store_document_embeddings_index_creation(self, processor):
        """Test automatic index creation when index doesn't exist"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk=b"Test document content",
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.chunks = [chunk]
        # Mock index doesn't exist initially
        processor.pinecone.has_index.return_value = False
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        # Mock index creation: report "ready" so the wait loop exits immediately.
        processor.pinecone.describe_index.return_value.status = {"ready": True}
        with patch('uuid.uuid4', return_value='test-id'):
            await processor.store_document_embeddings(message)
        # Verify index creation was called
        expected_index_name = "d-test_user-test_collection-3"
        processor.pinecone.create_index.assert_called_once()
        create_call = processor.pinecone.create_index.call_args
        assert create_call[1]['name'] == expected_index_name
        # Dimension is taken from the vector length; metric is fixed to cosine.
        assert create_call[1]['dimension'] == 3
        assert create_call[1]['metric'] == "cosine"

    @pytest.mark.asyncio
    async def test_store_document_embeddings_empty_chunk(self, processor):
        """Test storing document embeddings with empty chunk (should be skipped)"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk=b"",
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.chunks = [chunk]
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        await processor.store_document_embeddings(message)
        # Verify no upsert was called for empty chunk
        mock_index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_document_embeddings_none_chunk(self, processor):
        """Test storing document embeddings with None chunk (should be skipped)"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk=None,
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.chunks = [chunk]
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        await processor.store_document_embeddings(message)
        # Verify no upsert was called for None chunk
        mock_index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_document_embeddings_empty_decoded_chunk(self, processor):
        """Test storing document embeddings with chunk that decodes to empty string"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk=b"",  # Empty bytes
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.chunks = [chunk]
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        await processor.store_document_embeddings(message)
        # Verify no upsert was called for empty decoded chunk
        mock_index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_document_embeddings_different_vector_dimensions(self, processor):
        """Test storing document embeddings with different vector dimensions"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk=b"Document with mixed dimensions",
            vectors=[
                [0.1, 0.2],  # 2D vector
                [0.3, 0.4, 0.5, 0.6],  # 4D vector
                [0.7, 0.8, 0.9]  # 3D vector
            ]
        )
        message.chunks = [chunk]
        mock_index_2d = MagicMock()
        mock_index_4d = MagicMock()
        mock_index_3d = MagicMock()

        # Route each per-dimension index name (suffix "-<dim>") to its own mock.
        def mock_index_side_effect(name):
            if name.endswith("-2"):
                return mock_index_2d
            elif name.endswith("-4"):
                return mock_index_4d
            elif name.endswith("-3"):
                return mock_index_3d

        processor.pinecone.Index.side_effect = mock_index_side_effect
        processor.pinecone.has_index.return_value = True
        with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
            await processor.store_document_embeddings(message)
        # Verify different indexes were used for different dimensions
        assert processor.pinecone.Index.call_count == 3
        mock_index_2d.upsert.assert_called_once()
        mock_index_4d.upsert.assert_called_once()
        mock_index_3d.upsert.assert_called_once()

    @pytest.mark.asyncio
    async def test_store_document_embeddings_empty_chunks_list(self, processor):
        """Test storing document embeddings with empty chunks list"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        message.chunks = []
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        await processor.store_document_embeddings(message)
        # Verify no operations were performed
        processor.pinecone.Index.assert_not_called()
        mock_index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
        """Test storing document embeddings for chunk with no vectors"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk=b"Document with no vectors",
            vectors=[]
        )
        message.chunks = [chunk]
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        await processor.store_document_embeddings(message)
        # Verify no upsert was called (no vectors to insert)
        mock_index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_document_embeddings_index_creation_failure(self, processor):
        """Test handling of index creation failure"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk=b"Test document content",
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.chunks = [chunk]
        # Mock index doesn't exist and creation fails
        processor.pinecone.has_index.return_value = False
        processor.pinecone.create_index.side_effect = Exception("Index creation failed")
        # The creation error is expected to propagate unmodified.
        with pytest.raises(Exception, match="Index creation failed"):
            await processor.store_document_embeddings(message)

    @pytest.mark.asyncio
    async def test_store_document_embeddings_index_creation_timeout(self, processor):
        """Test handling of index creation timeout"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk=b"Test document content",
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.chunks = [chunk]
        # Mock index doesn't exist and never becomes ready
        processor.pinecone.has_index.return_value = False
        processor.pinecone.describe_index.return_value.status = {"ready": False}
        with patch('time.sleep'):  # Speed up the test
            with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
                await processor.store_document_embeddings(message)

    @pytest.mark.asyncio
    async def test_store_document_embeddings_unicode_content(self, processor):
        """Test storing document embeddings with Unicode content"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        chunk = ChunkEmbeddings(
            chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.chunks = [chunk]
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        processor.pinecone.has_index.return_value = True
        with patch('uuid.uuid4', return_value='test-id'):
            await processor.store_document_embeddings(message)
        # Verify Unicode content was properly decoded and stored
        call_args = mock_index.upsert.call_args
        stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
        assert stored_doc == "Document with Unicode: éñ中文🚀"

    @pytest.mark.asyncio
    async def test_store_document_embeddings_large_chunks(self, processor):
        """Test storing document embeddings with large document chunks"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        # Create a large document chunk
        large_content = "A" * 10000  # 10KB of content
        chunk = ChunkEmbeddings(
            chunk=large_content.encode('utf-8'),
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.chunks = [chunk]
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        processor.pinecone.has_index.return_value = True
        with patch('uuid.uuid4', return_value='test-id'):
            await processor.store_document_embeddings(message)
        # Verify large content was stored
        call_args = mock_index.upsert.call_args
        stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
        assert stored_doc == large_content

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        # Mock the parent class add_args method
        with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)
            # Verify parent add_args was called
            mock_parent_add_args.assert_called_once()
            # Verify our specific arguments were added
            args = parser.parse_args([])
            assert hasattr(args, 'api_key')
            # NOTE(review): this assumes PINECONE_API_KEY is unset in the test
            # environment — the default comes from the module-level env lookup.
            assert args.api_key == 'not-specified'  # Default value when no env var
            assert hasattr(args, 'url')
            assert args.url is None
            assert hasattr(args, 'cloud')
            assert args.cloud == 'aws'
            assert hasattr(args, 'region')
            assert args.region == 'us-east-1'

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
            Processor.add_args(parser)
            # Test parsing with custom values
            args = parser.parse_args([
                '--api-key', 'custom-api-key',
                '--url', 'https://custom-host.pinecone.io',
                '--cloud', 'gcp',
                '--region', 'us-west1'
            ])
            assert args.api_key == 'custom-api-key'
            assert args.url == 'https://custom-host.pinecone.io'
            assert args.cloud == 'gcp'
            assert args.region == 'us-west1'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
            Processor.add_args(parser)
            # Test parsing with short form (-a == --api-key, -u == --url)
            args = parser.parse_args([
                '-a', 'short-api-key',
                '-u', 'https://short-host.pinecone.io'
            ])
            assert args.api_key == 'short-api-key'
            assert args.url == 'https://short-host.pinecone.io'

    @patch('trustgraph.storage.doc_embeddings.pinecone.write.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.storage.doc_embeddings.pinecone.write import run, default_ident
        run()
        mock_launch.assert_called_once_with(
            default_ident,
            "\nAccepts document chunks/vector pairs and writes them to a Pinecone store.\n"
        )

View file

@ -0,0 +1,354 @@
"""
Tests for Milvus graph embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.graph_embeddings.milvus.write import Processor
from trustgraph.schema import Value, EntityEmbeddings
class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test cases for Milvus graph embeddings storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entities with embeddings
entity1 = EntityEmbeddings(
entity=Value(value='http://example.com/entity1', is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
entity2 = EntityEmbeddings(
entity=Value(value='literal entity', is_uri=False),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [entity1, entity2]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors') as mock_entity_vectors:
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=MagicMock(),
id='test-milvus-ge-storage',
store_uri='http://localhost:19530'
)
return processor
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
def test_processor_initialization_with_defaults(self, mock_entity_vectors):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(taskgroup=taskgroup_mock)
mock_entity_vectors.assert_called_once_with('http://localhost:19530')
assert processor.vecstore == mock_vecstore
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
def test_processor_initialization_with_custom_params(self, mock_entity_vectors):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=taskgroup_mock,
store_uri='http://custom-milvus:19530'
)
mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530')
assert processor.vecstore == mock_vecstore
@pytest.mark.asyncio
async def test_store_graph_embeddings_single_entity(self, processor):
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify insert was called for each vector
expected_calls = [
([0.1, 0.2, 0.3], 'http://example.com/entity'),
([0.4, 0.5, 0.6], 'http://example.com/entity'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
    @pytest.mark.asyncio
    async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
        """Test storing graph embeddings for multiple entities"""
        await processor.store_graph_embeddings(mock_message)
        # Verify insert was called for each vector of each entity,
        # preserving the entity order from the message.
        expected_calls = [
            # Entity 1 vectors
            ([0.1, 0.2, 0.3], 'http://example.com/entity1'),
            ([0.4, 0.5, 0.6], 'http://example.com/entity1'),
            # Entity 2 vectors
            ([0.7, 0.8, 0.9], 'literal entity'),
        ]
        assert processor.vecstore.insert.call_count == 3
        for i, (expected_vec, expected_entity) in enumerate(expected_calls):
            actual_call = processor.vecstore.insert.call_args_list[i]
            assert actual_call[0][0] == expected_vec
            assert actual_call[0][1] == expected_entity
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entity_value(self, processor):
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='', is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify no insert was called for empty entity
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_none_entity_value(self, processor):
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify no insert was called for None entity
processor.vecstore.insert.assert_not_called()
    @pytest.mark.asyncio
    async def test_store_graph_embeddings_mixed_valid_invalid_entities(self, processor):
        """Test storing graph embeddings with mix of valid and invalid entities"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        # A normal entity that should be stored.
        valid_entity = EntityEmbeddings(
            entity=Value(value='http://example.com/valid', is_uri=True),
            vectors=[[0.1, 0.2, 0.3]]
        )
        # Empty value: expected to be filtered out by the processor.
        empty_entity = EntityEmbeddings(
            entity=Value(value='', is_uri=False),
            vectors=[[0.4, 0.5, 0.6]]
        )
        # Missing value: also expected to be filtered out.
        none_entity = EntityEmbeddings(
            entity=Value(value=None, is_uri=False),
            vectors=[[0.7, 0.8, 0.9]]
        )
        message.entities = [valid_entity, empty_entity, none_entity]
        await processor.store_graph_embeddings(message)
        # Verify only valid entity was inserted
        processor.vecstore.insert.assert_called_once_with(
            [0.1, 0.2, 0.3], 'http://example.com/valid'
        )
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entities_list(self, processor):
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
await processor.store_graph_embeddings(message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
    """A valid entity with an empty vector list produces no store writes."""
    msg = MagicMock()
    msg.metadata = MagicMock()
    msg.metadata.user = 'test_user'
    msg.metadata.collection = 'test_collection'
    msg.entities = [
        EntityEmbeddings(
            entity=Value(value='http://example.com/entity', is_uri=True),
            vectors=[],
        )
    ]

    await processor.store_graph_embeddings(msg)

    # No vectors were supplied, so there is nothing to insert.
    processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
    """Test storing graph embeddings with different vector dimensions.

    The store is expected to forward every vector unchanged and in order,
    regardless of its dimensionality.
    """
    # Local import: `call` may not be part of the module-level mock imports.
    from unittest.mock import call

    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'
    entity = EntityEmbeddings(
        entity=Value(value='http://example.com/entity', is_uri=True),
        vectors=[
            [0.1, 0.2],            # 2D vector
            [0.3, 0.4, 0.5, 0.6],  # 4D vector
            [0.7, 0.8, 0.9]        # 3D vector
        ]
    )
    message.entities = [entity]

    await processor.store_graph_embeddings(message)

    # Comparing call_args_list against mock.call objects checks the call
    # count, ordering, argument values AND that no unexpected keyword
    # arguments were used — stricter than the previous per-slot loop,
    # which only inspected the first two positional arguments.
    assert processor.vecstore.insert.call_args_list == [
        call([0.1, 0.2], 'http://example.com/entity'),
        call([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'),
        call([0.7, 0.8, 0.9], 'http://example.com/entity'),
    ]
@pytest.mark.asyncio
async def test_store_graph_embeddings_uri_and_literal_entities(self, processor):
    """Both URI-valued and literal-valued entities are written to the store."""
    msg = MagicMock()
    msg.metadata = MagicMock()
    msg.metadata.user = 'test_user'
    msg.metadata.collection = 'test_collection'
    msg.entities = [
        EntityEmbeddings(
            entity=Value(value='http://example.com/uri_entity', is_uri=True),
            vectors=[[0.1, 0.2, 0.3]],
        ),
        EntityEmbeddings(
            entity=Value(value='literal entity text', is_uri=False),
            vectors=[[0.4, 0.5, 0.6]],
        ),
    ]

    await processor.store_graph_embeddings(msg)

    # One insert per entity, preserving order: URI first, then literal.
    assert processor.vecstore.insert.call_count == 2
    expected = [
        ([0.1, 0.2, 0.3], 'http://example.com/uri_entity'),
        ([0.4, 0.5, 0.6], 'literal entity text'),
    ]
    recorded = processor.vecstore.insert.call_args_list
    for actual, (vec, name) in zip(recorded, expected):
        assert actual[0][0] == vec
        assert actual[0][1] == name
def test_add_args_method(self):
    """add_args delegates to the parent and registers --store-uri with its default."""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parser = ArgumentParser()
    # Stub out the parent class hook so only this processor's options are added.
    with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args') as parent_add_args:
        Processor.add_args(parser)

    parent_add_args.assert_called_once()

    # With no CLI arguments the default Milvus endpoint must apply.
    args = parser.parse_args([])
    assert hasattr(args, 'store_uri')
    assert args.store_uri == 'http://localhost:19530'
def test_add_args_with_custom_values(self):
    """--store-uri on the command line overrides the default endpoint."""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parser = ArgumentParser()
    with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
        Processor.add_args(parser)

    parsed = parser.parse_args(['--store-uri', 'http://custom-milvus:19530'])
    assert parsed.store_uri == 'http://custom-milvus:19530'
def test_add_args_short_form(self):
    """The -t short option is an alias for --store-uri."""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parser = ArgumentParser()
    with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
        Processor.add_args(parser)

    parsed = parser.parse_args(['-t', 'http://short-milvus:19530'])
    assert parsed.store_uri == 'http://short-milvus:19530'
@patch('trustgraph.storage.graph_embeddings.milvus.write.Processor.launch')
def test_run_function(self, mock_launch):
    """Test the run function calls Processor.launch with correct parameters"""
    # Imported locally to pick up run() and the module's default ident.
    from trustgraph.storage.graph_embeddings.milvus.write import run, default_ident
    run()
    # launch() must receive the default ident and the banner text verbatim.
    mock_launch.assert_called_once_with(
        default_ident,
        "\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
    )

View file

@ -0,0 +1,460 @@
"""
Tests for Pinecone graph embeddings storage service
"""
import pytest
from unittest.mock import MagicMock, patch
import uuid
from trustgraph.storage.graph_embeddings.pinecone.write import Processor
from trustgraph.schema import EntityEmbeddings, Value
class TestPineconeGraphEmbeddingsStorageProcessor:
    """Test cases for Pinecone graph embeddings storage processor.

    The Pinecone client classes are always patched out, so these tests
    exercise the processor's own logic — index naming, lazy index
    creation, per-vector upserts and skipping of unusable entities —
    without any network access.
    """

    @pytest.fixture
    def mock_message(self):
        """Create a mock message for testing"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        # Create test entity embeddings: one URI entity with two vectors,
        # one literal entity with a single vector (three upserts total).
        entity1 = EntityEmbeddings(
            entity=Value(value="http://example.org/entity1", is_uri=True),
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        )
        entity2 = EntityEmbeddings(
            entity=Value(value="entity2", is_uri=False),
            vectors=[[0.7, 0.8, 0.9]]
        )
        message.entities = [entity1, entity2]
        return message

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        # Patch the Pinecone client class so constructing the processor
        # performs no API calls; the mock client is reachable later via
        # processor.pinecone.
        with patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone') as mock_pinecone_class:
            mock_pinecone = MagicMock()
            mock_pinecone_class.return_value = mock_pinecone
            processor = Processor(
                taskgroup=MagicMock(),
                id='test-pinecone-ge-storage',
                api_key='test-api-key'
            )
            return processor

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
    @patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'env-api-key')
    def test_processor_initialization_with_defaults(self, mock_pinecone_class):
        """Test processor initialization with default parameters"""
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(taskgroup=taskgroup_mock)
        # With no explicit api_key the module-level default must be used.
        mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
        assert processor.pinecone == mock_pinecone
        assert processor.api_key == 'env-api-key'
        assert processor.cloud == 'aws'
        assert processor.region == 'us-east-1'

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
    def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
        """Test processor initialization with custom parameters"""
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='custom-api-key',
            cloud='gcp',
            region='us-west1'
        )
        mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
        assert processor.api_key == 'custom-api-key'
        assert processor.cloud == 'gcp'
        assert processor.region == 'us-west1'

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.PineconeGRPC')
    def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
        """Test processor initialization with custom URL (GRPC mode)"""
        # Supplying a url switches the processor to the GRPC client class.
        mock_pinecone = MagicMock()
        mock_pinecone_grpc_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()
        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='test-api-key',
            url='https://custom-host.pinecone.io'
        )
        mock_pinecone_grpc_class.assert_called_once_with(
            api_key='test-api-key',
            host='https://custom-host.pinecone.io'
        )
        assert processor.pinecone == mock_pinecone
        assert processor.url == 'https://custom-host.pinecone.io'

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'not-specified')
    def test_processor_initialization_missing_api_key(self):
        """Test processor initialization fails with missing API key"""
        # 'not-specified' is the sentinel default; construction must refuse it.
        taskgroup_mock = MagicMock()
        with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
            Processor(taskgroup=taskgroup_mock)

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_single_entity(self, processor):
        """Test storing graph embeddings for a single entity"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        entity = EntityEmbeddings(
            entity=Value(value="http://example.org/entity1", is_uri=True),
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        )
        message.entities = [entity]
        # Mock index operations; has_index=True skips the creation path.
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        processor.pinecone.has_index.return_value = True
        # Pin the generated vector ids so they can be asserted below.
        with patch('uuid.uuid4', side_effect=['id1', 'id2']):
            await processor.store_graph_embeddings(message)
        # Verify index name and operations.  The index name encodes the
        # user, the collection and the vector dimension (3 here).
        expected_index_name = "t-test_user-test_collection-3"
        processor.pinecone.Index.assert_called_with(expected_index_name)
        # Verify upsert was called for each vector
        assert mock_index.upsert.call_count == 2
        # Check first vector upsert
        first_call = mock_index.upsert.call_args_list[0]
        first_vectors = first_call[1]['vectors']
        assert len(first_vectors) == 1
        assert first_vectors[0]['id'] == 'id1'
        assert first_vectors[0]['values'] == [0.1, 0.2, 0.3]
        assert first_vectors[0]['metadata']['entity'] == "http://example.org/entity1"
        # Check second vector upsert
        second_call = mock_index.upsert.call_args_list[1]
        second_vectors = second_call[1]['vectors']
        assert len(second_vectors) == 1
        assert second_vectors[0]['id'] == 'id2'
        assert second_vectors[0]['values'] == [0.4, 0.5, 0.6]
        assert second_vectors[0]['metadata']['entity'] == "http://example.org/entity1"

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
        """Test storing graph embeddings for multiple entities"""
        # Mock index operations
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        processor.pinecone.has_index.return_value = True
        with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
            await processor.store_graph_embeddings(mock_message)
        # Verify upsert was called for each vector (3 total)
        assert mock_index.upsert.call_count == 3
        # Verify entity values in metadata: first two upserts belong to
        # entity1 (two vectors), the third to entity2.
        calls = mock_index.upsert.call_args_list
        assert calls[0][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1"
        assert calls[1][1]['vectors'][0]['metadata']['entity'] == "http://example.org/entity1"
        assert calls[2][1]['vectors'][0]['metadata']['entity'] == "entity2"

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_index_creation(self, processor):
        """Test automatic index creation when index doesn't exist"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        entity = EntityEmbeddings(
            entity=Value(value="test_entity", is_uri=False),
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.entities = [entity]
        # Mock index doesn't exist initially
        processor.pinecone.has_index.return_value = False
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        # Mock index creation: report the new index as immediately ready
        # so the readiness-polling loop exits on the first check.
        processor.pinecone.describe_index.return_value.status = {"ready": True}
        with patch('uuid.uuid4', return_value='test-id'):
            await processor.store_graph_embeddings(message)
        # Verify index creation was called with name, dimension and metric.
        expected_index_name = "t-test_user-test_collection-3"
        processor.pinecone.create_index.assert_called_once()
        create_call = processor.pinecone.create_index.call_args
        assert create_call[1]['name'] == expected_index_name
        assert create_call[1]['dimension'] == 3
        assert create_call[1]['metric'] == "cosine"

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_empty_entity_value(self, processor):
        """Test storing graph embeddings with empty entity value (should be skipped)"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        entity = EntityEmbeddings(
            entity=Value(value="", is_uri=False),
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.entities = [entity]
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        await processor.store_graph_embeddings(message)
        # Verify no upsert was called for empty entity
        mock_index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_none_entity_value(self, processor):
        """Test storing graph embeddings with None entity value (should be skipped)"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        entity = EntityEmbeddings(
            entity=Value(value=None, is_uri=False),
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.entities = [entity]
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        await processor.store_graph_embeddings(message)
        # Verify no upsert was called for None entity
        mock_index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
        """Test storing graph embeddings with different vector dimensions"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        entity = EntityEmbeddings(
            entity=Value(value="test_entity", is_uri=False),
            vectors=[
                [0.1, 0.2],  # 2D vector
                [0.3, 0.4, 0.5, 0.6],  # 4D vector
                [0.7, 0.8, 0.9]  # 3D vector
            ]
        )
        message.entities = [entity]
        mock_index_2d = MagicMock()
        mock_index_4d = MagicMock()
        mock_index_3d = MagicMock()
        # The index name ends in the vector dimension, so dispatch on the
        # suffix to hand back a distinct mock per dimension.
        def mock_index_side_effect(name):
            if name.endswith("-2"):
                return mock_index_2d
            elif name.endswith("-4"):
                return mock_index_4d
            elif name.endswith("-3"):
                return mock_index_3d
        processor.pinecone.Index.side_effect = mock_index_side_effect
        processor.pinecone.has_index.return_value = True
        with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
            await processor.store_graph_embeddings(message)
        # Verify different indexes were used for different dimensions
        assert processor.pinecone.Index.call_count == 3
        mock_index_2d.upsert.assert_called_once()
        mock_index_4d.upsert.assert_called_once()
        mock_index_3d.upsert.assert_called_once()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_empty_entities_list(self, processor):
        """Test storing graph embeddings with empty entities list"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        message.entities = []
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        await processor.store_graph_embeddings(message)
        # Verify no operations were performed: with no entities the
        # processor should not even resolve an index.
        processor.pinecone.Index.assert_not_called()
        mock_index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
        """Test storing graph embeddings for entity with no vectors"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        entity = EntityEmbeddings(
            entity=Value(value="test_entity", is_uri=False),
            vectors=[]
        )
        message.entities = [entity]
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        await processor.store_graph_embeddings(message)
        # Verify no upsert was called (no vectors to insert)
        mock_index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_index_creation_failure(self, processor):
        """Test handling of index creation failure"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        entity = EntityEmbeddings(
            entity=Value(value="test_entity", is_uri=False),
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.entities = [entity]
        # Mock index doesn't exist and creation fails; the exception is
        # expected to propagate unswallowed to the caller.
        processor.pinecone.has_index.return_value = False
        processor.pinecone.create_index.side_effect = Exception("Index creation failed")
        with pytest.raises(Exception, match="Index creation failed"):
            await processor.store_graph_embeddings(message)

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_index_creation_timeout(self, processor):
        """Test handling of index creation timeout"""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        entity = EntityEmbeddings(
            entity=Value(value="test_entity", is_uri=False),
            vectors=[[0.1, 0.2, 0.3]]
        )
        message.entities = [entity]
        # Mock index doesn't exist and never becomes ready
        processor.pinecone.has_index.return_value = False
        processor.pinecone.describe_index.return_value.status = {"ready": False}
        with patch('time.sleep'):  # Speed up the test
            with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
                await processor.store_graph_embeddings(message)

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        # Mock the parent class add_args method
        with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)
        # Verify parent add_args was called
        mock_parent_add_args.assert_called_once()
        # Verify our specific arguments were added by parsing empty args
        args = parser.parse_args([])
        assert hasattr(args, 'api_key')
        assert args.api_key == 'not-specified'  # Default value when no env var
        assert hasattr(args, 'url')
        assert args.url is None
        assert hasattr(args, 'cloud')
        assert args.cloud == 'aws'
        assert hasattr(args, 'region')
        assert args.region == 'us-east-1'

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
            Processor.add_args(parser)
        # Test parsing with custom values
        args = parser.parse_args([
            '--api-key', 'custom-api-key',
            '--url', 'https://custom-host.pinecone.io',
            '--cloud', 'gcp',
            '--region', 'us-west1'
        ])
        assert args.api_key == 'custom-api-key'
        assert args.url == 'https://custom-host.pinecone.io'
        assert args.cloud == 'gcp'
        assert args.region == 'us-west1'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser
        from unittest.mock import patch
        parser = ArgumentParser()
        with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
            Processor.add_args(parser)
        # Test parsing with short form (-a = --api-key, -u = --url)
        args = parser.parse_args([
            '-a', 'short-api-key',
            '-u', 'https://short-host.pinecone.io'
        ])
        assert args.api_key == 'short-api-key'
        assert args.url == 'https://short-host.pinecone.io'

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.storage.graph_embeddings.pinecone.write import run, default_ident
        run()
        # launch() must receive the default ident and the banner text verbatim.
        mock_launch.assert_called_once_with(
            default_ident,
            "\nAccepts entity/vector pairs and writes them to a Pinecone store.\n"
        )

View file

@ -0,0 +1,436 @@
"""
Tests for FalkorDB triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.falkordb.write import Processor
from trustgraph.schema import Value, Triple
class TestFalkorDBStorageProcessor:
"""Test cases for FalkorDB storage processor"""
@pytest.fixture
def mock_message(self):
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a test triple
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
message.triples = [triple]
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.storage.triples.falkordb.write.FalkorDB') as mock_falkordb:
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
return Processor(
taskgroup=MagicMock(),
id='test-falkordb-storage',
graph_url='falkor://localhost:6379',
database='test_db'
)
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
def test_processor_initialization_with_defaults(self, mock_falkordb):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'falkordb'
mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379')
mock_client.select_graph.assert_called_once_with('falkordb')
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
def test_processor_initialization_with_custom_params(self, mock_falkordb):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
processor = Processor(
taskgroup=taskgroup_mock,
graph_url='falkor://custom:6379',
database='custom_db'
)
assert processor.db == 'custom_db'
mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379')
mock_client.select_graph.assert_called_once_with('custom_db')
def test_create_node(self, processor):
"""Test node creation"""
test_uri = 'http://example.com/node'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_node(test_uri)
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri})",
params={
"uri": test_uri,
},
)
def test_create_literal(self, processor):
"""Test literal creation"""
test_value = 'test literal value'
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_literal(test_value)
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value})",
params={
"value": test_value,
},
)
def test_relate_node(self, processor):
"""Test node-to-node relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
dest_uri = 'http://example.com/dest'
mock_result = MagicMock()
mock_result.nodes_created = 0
mock_result.run_time_ms = 5
processor.io.query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri)
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
params={
"src": src_uri,
"dest": dest_uri,
"uri": pred_uri,
},
)
def test_relate_literal(self, processor):
"""Test node-to-literal relationship creation"""
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
literal_value = 'literal destination'
mock_result = MagicMock()
mock_result.nodes_created = 0
mock_result.run_time_ms = 5
processor.io.query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value)
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
params={
"src": src_uri,
"dest": literal_value,
"uri": pred_uri,
},
)
@pytest.mark.asyncio
async def test_store_triples_with_uri_object(self, processor):
"""Test storing triple with URI object"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='http://example.com/object', is_uri=True)
)
message.triples = [triple]
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
# Create object node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/object"}}),
# Create relationship
(("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate"}}),
]
assert processor.io.query.call_count == 3
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.io.query.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
@pytest.mark.asyncio
async def test_store_triples_with_literal_object(self, processor, mock_message):
"""Test storing triple with literal object"""
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(mock_message)
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
# Create literal object
(("MERGE (n:Literal {value: $value})",), {"params": {"value": "literal object"}}),
# Create relationship
(("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate"}}),
]
assert processor.io.query.call_count == 3
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.io.query.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
@pytest.mark.asyncio
async def test_store_triples_multiple_triples(self, processor):
"""Test storing multiple triples"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object1', is_uri=False)
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
)
message.triples = [triple1, triple2]
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
# Verify first triple operations
first_triple_calls = processor.io.query.call_args_list[0:3]
assert first_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject1"
assert first_triple_calls[1][1]["params"]["value"] == "literal object1"
assert first_triple_calls[2][1]["params"]["src"] == "http://example.com/subject1"
# Verify second triple operations
second_triple_calls = processor.io.query.call_args_list[3:6]
assert second_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject2"
assert second_triple_calls[1][1]["params"]["uri"] == "http://example.com/object2"
assert second_triple_calls[2][1]["params"]["src"] == "http://example.com/subject2"
@pytest.mark.asyncio
async def test_store_triples_empty_list(self, processor):
"""Test storing empty triples list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
await processor.store_triples(message)
# Verify no queries were made
processor.io.query.assert_not_called()
@pytest.mark.asyncio
async def test_store_triples_mixed_objects(self, processor):
"""Test storing triples with mixed URI and literal objects"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object', is_uri=False)
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
)
message.triples = [triple1, triple2]
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
# Verify first triple creates literal
assert "Literal" in processor.io.query.call_args_list[1][0][0]
assert processor.io.query.call_args_list[1][1]["params"]["value"] == "literal object"
# Verify second triple creates node
assert "Node" in processor.io.query.call_args_list[4][0][0]
assert processor.io.query.call_args_list[4][1]["params"]["uri"] == "http://example.com/object2"
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_url')
assert args.graph_url == 'falkor://falkordb:6379'
assert hasattr(args, 'database')
assert args.database == 'falkordb'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-url', 'falkor://custom:6379',
'--database', 'custom_db'
])
assert args.graph_url == 'falkor://custom:6379'
assert args.database == 'custom_db'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'falkor://short:6379'])
assert args.graph_url == 'falkor://short:6379'
@patch('trustgraph.storage.triples.falkordb.write.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.storage.triples.falkordb.write import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nGraph writer. Input is graph edge. Writes edges to FalkorDB graph.\n"
)
def test_create_node_with_special_characters(self, processor):
    """Test node creation with special characters in URI"""
    test_uri = 'http://example.com/node with spaces & symbols'
    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10
    processor.io.query.return_value = mock_result
    processor.create_node(test_uri)
    # The URI must be passed through as a bound parameter, unescaped.
    processor.io.query.assert_called_once_with(
        "MERGE (n:Node {uri: $uri})",
        params={
            "uri": test_uri,
        },
    )
def test_create_literal_with_special_characters(self, processor):
    """Test literal creation with special characters"""
    test_value = 'literal with "quotes" and \n newlines'
    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10
    processor.io.query.return_value = mock_result
    processor.create_literal(test_value)
    # Quotes/newlines must survive intact as a bound parameter value.
    processor.io.query.assert_called_once_with(
        "MERGE (n:Literal {value: $value})",
        params={
            "value": test_value,
        },
    )

View file

@ -0,0 +1,441 @@
"""
Tests for Memgraph triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.memgraph.write import Processor
from trustgraph.schema import Value, Triple
class TestMemgraphStorageProcessor:
"""Test cases for Memgraph storage processor"""
@pytest.fixture
def mock_message(self):
    """Create a mock message for testing"""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'
    # Create a test triple: URI subject/predicate with a literal object.
    triple = Triple(
        s=Value(value='http://example.com/subject', is_uri=True),
        p=Value(value='http://example.com/predicate', is_uri=True),
        o=Value(value='literal object', is_uri=False)
    )
    message.triples = [triple]
    return message
@pytest.fixture
def processor(self):
    """Create a processor instance for testing"""
    # GraphDatabase is patched only while the Processor is constructed; the
    # returned Processor keeps references to the mock driver and session.
    with patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') as mock_graph_db:
        mock_driver = MagicMock()
        mock_session = MagicMock()
        mock_graph_db.driver.return_value = mock_driver
        mock_driver.session.return_value.__enter__.return_value = mock_session
        return Processor(
            taskgroup=MagicMock(),
            id='test-memgraph-storage',
            graph_host='bolt://localhost:7687',
            username='test_user',
            password='test_pass',
            database='test_db'
        )
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
    """Test processor initialization with default parameters"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_session = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_driver.session.return_value.__enter__.return_value = mock_session
    processor = Processor(taskgroup=taskgroup_mock)
    # Defaults: database 'memgraph', host 'bolt://memgraph:7687', auth memgraph/password.
    assert processor.db == 'memgraph'
    mock_graph_db.driver.assert_called_once_with(
        'bolt://memgraph:7687',
        auth=('memgraph', 'password')
    )
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
    """Test processor initialization with custom parameters"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_session = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_driver.session.return_value.__enter__.return_value = mock_session
    processor = Processor(
        taskgroup=taskgroup_mock,
        graph_host='bolt://custom:7687',
        username='custom_user',
        password='custom_pass',
        database='custom_db'
    )
    # Custom host/credentials/database must be forwarded verbatim to the driver.
    assert processor.db == 'custom_db'
    mock_graph_db.driver.assert_called_once_with(
        'bolt://custom:7687',
        auth=('custom_user', 'custom_pass')
    )
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_indexes_success(self, mock_graph_db):
    """Test successful index creation"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_session = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_driver.session.return_value.__enter__.return_value = mock_session
    processor = Processor(taskgroup=taskgroup_mock)
    # Verify index creation calls, issued in this exact order at init time.
    expected_calls = [
        "CREATE INDEX ON :Node",
        "CREATE INDEX ON :Node(uri)",
        "CREATE INDEX ON :Literal",
        "CREATE INDEX ON :Literal(value)"
    ]
    assert mock_session.run.call_count == len(expected_calls)
    for i, expected_call in enumerate(expected_calls):
        actual_call = mock_session.run.call_args_list[i][0][0]
        assert actual_call == expected_call
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_indexes_with_exceptions(self, mock_graph_db):
    """Test index creation with exceptions (should be ignored)"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_session = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Make all index creation calls raise exceptions (e.g. index already exists).
    mock_session.run.side_effect = Exception("Index already exists")
    # Should not raise an exception
    processor = Processor(taskgroup=taskgroup_mock)
    # Verify all index creation calls were attempted despite the failures.
    assert mock_session.run.call_count == 4
def test_create_node(self, processor):
    """Test node creation"""
    test_uri = 'http://example.com/node'
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 1
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    processor.io.execute_query.return_value = mock_result
    processor.create_node(test_uri)
    # MERGE is idempotent; the query targets the processor's configured database.
    processor.io.execute_query.assert_called_once_with(
        "MERGE (n:Node {uri: $uri})",
        uri=test_uri,
        database_=processor.db
    )
def test_create_literal(self, processor):
    """Test literal creation"""
    test_value = 'test literal value'
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 1
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    processor.io.execute_query.return_value = mock_result
    processor.create_literal(test_value)
    # Literal nodes carry the raw value as a bound parameter.
    processor.io.execute_query.assert_called_once_with(
        "MERGE (n:Literal {value: $value})",
        value=test_value,
        database_=processor.db
    )
def test_relate_node(self, processor):
    """Test node-to-node relationship creation"""
    src_uri = 'http://example.com/src'
    pred_uri = 'http://example.com/pred'
    dest_uri = 'http://example.com/dest'
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 0
    mock_summary.result_available_after = 5
    mock_result.summary = mock_summary
    processor.io.execute_query.return_value = mock_result
    processor.relate_node(src_uri, pred_uri, dest_uri)
    # The predicate URI becomes the relationship's 'uri' property.
    processor.io.execute_query.assert_called_once_with(
        "MATCH (src:Node {uri: $src}) "
        "MATCH (dest:Node {uri: $dest}) "
        "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
        src=src_uri, dest=dest_uri, uri=pred_uri,
        database_=processor.db
    )
def test_relate_literal(self, processor):
    """Test node-to-literal relationship creation"""
    src_uri = 'http://example.com/src'
    pred_uri = 'http://example.com/pred'
    literal_value = 'literal destination'
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 0
    mock_summary.result_available_after = 5
    mock_result.summary = mock_summary
    processor.io.execute_query.return_value = mock_result
    processor.relate_literal(src_uri, pred_uri, literal_value)
    # Destination is matched on the Literal label, keyed by value not uri.
    processor.io.execute_query.assert_called_once_with(
        "MATCH (src:Node {uri: $src}) "
        "MATCH (dest:Literal {value: $dest}) "
        "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
        src=src_uri, dest=literal_value, uri=pred_uri,
        database_=processor.db
    )
def test_create_triple_with_uri_object(self, processor):
    """Test triple creation with URI object"""
    mock_tx = MagicMock()
    triple = Triple(
        s=Value(value='http://example.com/subject', is_uri=True),
        p=Value(value='http://example.com/predicate', is_uri=True),
        o=Value(value='http://example.com/object', is_uri=True)
    )
    processor.create_triple(mock_tx, triple)
    # Verify transaction calls: two Node MERGEs then one relationship MERGE.
    expected_calls = [
        # Create subject node
        ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
        # Create object node
        ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/object'}),
        # Create relationship
        ("MATCH (src:Node {uri: $src}) "
         "MATCH (dest:Node {uri: $dest}) "
         "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
         {'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate'})
    ]
    assert mock_tx.run.call_count == 3
    for i, (expected_query, expected_params) in enumerate(expected_calls):
        actual_call = mock_tx.run.call_args_list[i]
        assert actual_call[0][0] == expected_query
        assert actual_call[1] == expected_params
def test_create_triple_with_literal_object(self, processor):
    """Test triple creation with literal object"""
    mock_tx = MagicMock()
    triple = Triple(
        s=Value(value='http://example.com/subject', is_uri=True),
        p=Value(value='http://example.com/predicate', is_uri=True),
        o=Value(value='literal object', is_uri=False)
    )
    processor.create_triple(mock_tx, triple)
    # Verify transaction calls: a literal object routes through the Literal label.
    expected_calls = [
        # Create subject node
        ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
        # Create literal object
        ("MERGE (n:Literal {value: $value})", {'value': 'literal object'}),
        # Create relationship
        ("MATCH (src:Node {uri: $src}) "
         "MATCH (dest:Literal {value: $dest}) "
         "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
         {'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate'})
    ]
    assert mock_tx.run.call_count == 3
    for i, (expected_query, expected_params) in enumerate(expected_calls):
        actual_call = mock_tx.run.call_args_list[i]
        assert actual_call[0][0] == expected_query
        assert actual_call[1] == expected_params
@pytest.mark.asyncio
async def test_store_triples_single_triple(self, processor, mock_message):
    """Test storing a single triple"""
    mock_session = MagicMock()
    processor.io.session.return_value.__enter__.return_value = mock_session
    # Reset the mock to clear the initialization call
    processor.io.session.reset_mock()
    await processor.store_triples(mock_message)
    # Verify session was created with correct database
    processor.io.session.assert_called_once_with(database=processor.db)
    # Verify execute_write was called once per triple
    mock_session.execute_write.assert_called_once()
    # Verify the triple was passed to create_triple as the write callback's arg.
    call_args = mock_session.execute_write.call_args
    assert call_args[0][0] == processor.create_triple
    assert call_args[0][1] == mock_message.triples[0]
@pytest.mark.asyncio
async def test_store_triples_multiple_triples(self, processor):
    """Test storing multiple triples"""
    mock_session = MagicMock()
    processor.io.session.return_value.__enter__.return_value = mock_session
    # Reset the mock to clear the initialization call
    processor.io.session.reset_mock()
    # Create message with multiple triples
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'
    triple1 = Triple(
        s=Value(value='http://example.com/subject1', is_uri=True),
        p=Value(value='http://example.com/predicate1', is_uri=True),
        o=Value(value='literal object1', is_uri=False)
    )
    triple2 = Triple(
        s=Value(value='http://example.com/subject2', is_uri=True),
        p=Value(value='http://example.com/predicate2', is_uri=True),
        o=Value(value='http://example.com/object2', is_uri=True)
    )
    message.triples = [triple1, triple2]
    await processor.store_triples(message)
    # Verify session was called twice (once per triple)
    assert processor.io.session.call_count == 2
    # Verify execute_write was called once per triple
    assert mock_session.execute_write.call_count == 2
    # Verify each triple was processed, in message order.
    call_args_list = mock_session.execute_write.call_args_list
    assert call_args_list[0][0][1] == triple1
    assert call_args_list[1][0][1] == triple2
@pytest.mark.asyncio
async def test_store_triples_empty_list(self, processor):
    """Test storing empty triples list"""
    mock_session = MagicMock()
    processor.io.session.return_value.__enter__.return_value = mock_session
    # Reset the mock to clear the initialization call
    processor.io.session.reset_mock()
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'
    message.triples = []
    await processor.store_triples(message)
    # Verify no session calls were made (no triples to process)
    processor.io.session.assert_not_called()
    # Verify no execute_write calls were made
    mock_session.execute_write.assert_not_called()
def test_add_args_method(self):
    """Test that add_args properly configures argument parser"""
    from argparse import ArgumentParser
    from unittest.mock import patch
    parser = ArgumentParser()
    # Mock the parent class add_args method
    with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
        Processor.add_args(parser)
    # Verify parent add_args was called
    mock_parent_add_args.assert_called_once()
    # Verify our specific arguments were added.
    # Parse empty args to check the documented defaults.
    args = parser.parse_args([])
    assert hasattr(args, 'graph_host')
    assert args.graph_host == 'bolt://memgraph:7687'
    assert hasattr(args, 'username')
    assert args.username == 'memgraph'
    assert hasattr(args, 'password')
    assert args.password == 'password'
    assert hasattr(args, 'database')
    assert args.database == 'memgraph'
def test_add_args_with_custom_values(self):
    """Test add_args with custom command line values"""
    from argparse import ArgumentParser
    from unittest.mock import patch
    parser = ArgumentParser()
    with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
        Processor.add_args(parser)
    # Test parsing with custom values; '--graph-host' maps to args.graph_host.
    args = parser.parse_args([
        '--graph-host', 'bolt://custom:7687',
        '--username', 'custom_user',
        '--password', 'custom_pass',
        '--database', 'custom_db'
    ])
    assert args.graph_host == 'bolt://custom:7687'
    assert args.username == 'custom_user'
    assert args.password == 'custom_pass'
    assert args.database == 'custom_db'
def test_add_args_short_form(self):
    """Test add_args with short form arguments"""
    from argparse import ArgumentParser
    from unittest.mock import patch
    parser = ArgumentParser()
    with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
        Processor.add_args(parser)
    # Test parsing with short form: '-g' is the alias for '--graph-host'.
    args = parser.parse_args(['-g', 'bolt://short:7687'])
    assert args.graph_host == 'bolt://short:7687'
@patch('trustgraph.storage.triples.memgraph.write.Processor.launch')
def test_run_function(self, mock_launch):
    """Test the run function calls Processor.launch with correct parameters"""
    from trustgraph.storage.triples.memgraph.write import run, default_ident
    run()
    # run() must forward the module's default identity and its module docstring.
    mock_launch.assert_called_once_with(
        default_ident,
        "\nGraph writer. Input is graph edge. Writes edges to Memgraph.\n"
    )

View file

@ -0,0 +1,548 @@
"""
Tests for Neo4j triples storage service
"""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.storage.triples.neo4j.write import Processor
class TestNeo4jStorageProcessor:
"""Test cases for Neo4j storage processor"""
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
    """Test processor initialization with default parameters"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    processor = Processor(taskgroup=taskgroup_mock)
    # Defaults: database 'neo4j', host 'bolt://neo4j:7687', auth neo4j/password.
    assert processor.db == 'neo4j'
    mock_graph_db.driver.assert_called_once_with(
        'bolt://neo4j:7687',
        auth=('neo4j', 'password')
    )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
    """Test processor initialization with custom parameters"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    processor = Processor(
        taskgroup=taskgroup_mock,
        graph_host='bolt://custom:7687',
        username='testuser',
        password='testpass',
        database='testdb'
    )
    # Custom host/credentials/database must be forwarded verbatim to the driver.
    assert processor.db == 'testdb'
    mock_graph_db.driver.assert_called_once_with(
        'bolt://custom:7687',
        auth=('testuser', 'testpass')
    )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_indexes_success(self, mock_graph_db):
    """Test successful index creation"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    processor = Processor(taskgroup=taskgroup_mock)
    # Verify index creation queries were executed at init (order not asserted).
    expected_calls = [
        "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
        "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
        "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)"
    ]
    assert mock_session.run.call_count == 3
    for expected_query in expected_calls:
        mock_session.run.assert_any_call(expected_query)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_indexes_with_exceptions(self, mock_graph_db):
    """Test index creation with exceptions (should be ignored)"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Make session.run raise exceptions (e.g. index already exists).
    mock_session.run.side_effect = Exception("Index already exists")
    # Should not raise exception - they should be caught and ignored
    processor = Processor(taskgroup=taskgroup_mock)
    # Should have tried to create all 3 indexes despite exceptions
    assert mock_session.run.call_count == 3
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_node(self, mock_graph_db):
    """Test node creation"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Mock execute_query response with a summary shaped like the neo4j driver's.
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 1
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    mock_driver.execute_query.return_value = mock_result
    processor = Processor(taskgroup=taskgroup_mock)
    # Test create_node: MERGE keyed on uri, against the default database.
    processor.create_node("http://example.com/node")
    mock_driver.execute_query.assert_called_with(
        "MERGE (n:Node {uri: $uri})",
        uri="http://example.com/node",
        database_="neo4j"
    )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_create_literal(self, mock_graph_db):
    """Test literal creation"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Mock execute_query response
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 1
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    mock_driver.execute_query.return_value = mock_result
    processor = Processor(taskgroup=taskgroup_mock)
    # Test create_literal: MERGE keyed on value under the Literal label.
    processor.create_literal("literal value")
    mock_driver.execute_query.assert_called_with(
        "MERGE (n:Literal {value: $value})",
        value="literal value",
        database_="neo4j"
    )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_relate_node(self, mock_graph_db):
    """Test node-to-node relationship creation"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Mock execute_query response
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 0
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    mock_driver.execute_query.return_value = mock_result
    processor = Processor(taskgroup=taskgroup_mock)
    # Test relate_node: predicate URI becomes the relationship's 'uri' property.
    processor.relate_node(
        "http://example.com/subject",
        "http://example.com/predicate",
        "http://example.com/object"
    )
    mock_driver.execute_query.assert_called_with(
        "MATCH (src:Node {uri: $src}) "
        "MATCH (dest:Node {uri: $dest}) "
        "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
        src="http://example.com/subject",
        dest="http://example.com/object",
        uri="http://example.com/predicate",
        database_="neo4j"
    )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_relate_literal(self, mock_graph_db):
    """Test node-to-literal relationship creation"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Mock execute_query response
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 0
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    mock_driver.execute_query.return_value = mock_result
    processor = Processor(taskgroup=taskgroup_mock)
    # Test relate_literal: destination is matched on the Literal label by value.
    processor.relate_literal(
        "http://example.com/subject",
        "http://example.com/predicate",
        "literal value"
    )
    mock_driver.execute_query.assert_called_with(
        "MATCH (src:Node {uri: $src}) "
        "MATCH (dest:Literal {value: $dest}) "
        "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
        src="http://example.com/subject",
        dest="literal value",
        uri="http://example.com/predicate",
        database_="neo4j"
    )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_handle_triples_with_uri_object(self, mock_graph_db):
    """Test handling triples message with URI object"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Mock execute_query response
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 1
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    mock_driver.execute_query.return_value = mock_result
    processor = Processor(taskgroup=taskgroup_mock)
    # Create mock triple with URI object
    triple = MagicMock()
    triple.s.value = "http://example.com/subject"
    triple.p.value = "http://example.com/predicate"
    triple.o.value = "http://example.com/object"
    triple.o.is_uri = True
    # Create mock message
    mock_message = MagicMock()
    mock_message.triples = [triple]
    await processor.store_triples(mock_message)
    # Verify create_node was called for subject and object,
    # and relate_node for the edge (3 execute_query calls total).
    expected_calls = [
        # Subject node creation
        (
            "MERGE (n:Node {uri: $uri})",
            {"uri": "http://example.com/subject", "database_": "neo4j"}
        ),
        # Object node creation
        (
            "MERGE (n:Node {uri: $uri})",
            {"uri": "http://example.com/object", "database_": "neo4j"}
        ),
        # Relationship creation
        (
            "MATCH (src:Node {uri: $src}) "
            "MATCH (dest:Node {uri: $dest}) "
            "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
            {
                "src": "http://example.com/subject",
                "dest": "http://example.com/object",
                "uri": "http://example.com/predicate",
                "database_": "neo4j"
            }
        )
    ]
    assert mock_driver.execute_query.call_count == 3
    for expected_query, expected_params in expected_calls:
        mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_literal_object(self, mock_graph_db):
    """Test handling triples message with literal object"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Mock execute_query response
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 1
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    mock_driver.execute_query.return_value = mock_result
    processor = Processor(taskgroup=taskgroup_mock)
    # Create mock triple with literal object
    triple = MagicMock()
    triple.s.value = "http://example.com/subject"
    triple.p.value = "http://example.com/predicate"
    triple.o.value = "literal value"
    triple.o.is_uri = False
    # Create mock message
    mock_message = MagicMock()
    mock_message.triples = [triple]
    await processor.store_triples(mock_message)
    # Verify create_node (subject), create_literal (object) and
    # relate_literal (edge) were each issued once.
    expected_calls = [
        # Subject node creation
        (
            "MERGE (n:Node {uri: $uri})",
            {"uri": "http://example.com/subject", "database_": "neo4j"}
        ),
        # Literal creation
        (
            "MERGE (n:Literal {value: $value})",
            {"value": "literal value", "database_": "neo4j"}
        ),
        # Relationship creation
        (
            "MATCH (src:Node {uri: $src}) "
            "MATCH (dest:Literal {value: $dest}) "
            "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
            {
                "src": "http://example.com/subject",
                "dest": "literal value",
                "uri": "http://example.com/predicate",
                "database_": "neo4j"
            }
        )
    ]
    assert mock_driver.execute_query.call_count == 3
    for expected_query, expected_params in expected_calls:
        mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_multiple_triples(self, mock_graph_db):
    """Test handling message with multiple triples"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Mock execute_query response
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 1
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    mock_driver.execute_query.return_value = mock_result
    processor = Processor(taskgroup=taskgroup_mock)
    # Create mock triples: one URI object, one literal object.
    triple1 = MagicMock()
    triple1.s.value = "http://example.com/subject1"
    triple1.p.value = "http://example.com/predicate1"
    triple1.o.value = "http://example.com/object1"
    triple1.o.is_uri = True
    triple2 = MagicMock()
    triple2.s.value = "http://example.com/subject2"
    triple2.p.value = "http://example.com/predicate2"
    triple2.o.value = "literal value"
    triple2.o.is_uri = False
    # Create mock message
    mock_message = MagicMock()
    mock_message.triples = [triple1, triple2]
    await processor.store_triples(mock_message)
    # Should have processed both triples
    # Triple1: 2 nodes + 1 relationship = 3 calls
    # Triple2: 1 node + 1 literal + 1 relationship = 3 calls
    # Total: 6 calls
    assert mock_driver.execute_query.call_count == 6
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_empty_triples(self, mock_graph_db):
    """Test handling message with no triples"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    processor = Processor(taskgroup=taskgroup_mock)
    # Create mock message with empty triples
    mock_message = MagicMock()
    mock_message.triples = []
    await processor.store_triples(mock_message)
    # Should not have made any execute_query calls beyond index creation.
    # Index creation at init goes through session.run, not execute_query,
    # so execute_query must be completely untouched here.
    mock_driver.execute_query.assert_not_called()
def test_add_args_method(self):
    """Test that add_args properly configures argument parser"""
    from argparse import ArgumentParser
    from unittest.mock import patch
    parser = ArgumentParser()
    # Mock the parent class add_args method
    with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args') as mock_parent_add_args:
        Processor.add_args(parser)
    # Verify parent add_args was called
    mock_parent_add_args.assert_called_once()
    # Verify our specific arguments were added.
    # Parse empty args to check the documented defaults.
    args = parser.parse_args([])
    assert hasattr(args, 'graph_host')
    assert args.graph_host == 'bolt://neo4j:7687'
    assert hasattr(args, 'username')
    assert args.username == 'neo4j'
    assert hasattr(args, 'password')
    assert args.password == 'password'
    assert hasattr(args, 'database')
    assert args.database == 'neo4j'
def test_add_args_with_custom_values(self):
    """Test add_args with custom command line values"""
    from argparse import ArgumentParser
    from unittest.mock import patch
    parser = ArgumentParser()
    with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
        Processor.add_args(parser)
    # Test parsing with custom values.
    # Fix: use the hyphenated long option '--graph-host' (argparse derives the
    # dest 'graph_host' from it), matching the memgraph test and the hyphenated
    # style of every other long option in these tests; '--graph_host' was a typo.
    args = parser.parse_args([
        '--graph-host', 'bolt://custom:7687',
        '--username', 'testuser',
        '--password', 'testpass',
        '--database', 'testdb'
    ])
    assert args.graph_host == 'bolt://custom:7687'
    assert args.username == 'testuser'
    assert args.password == 'testpass'
    assert args.database == 'testdb'
def test_add_args_short_form(self):
    """Test add_args with short form arguments"""
    from argparse import ArgumentParser
    from unittest.mock import patch
    parser = ArgumentParser()
    with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
        Processor.add_args(parser)
    # Test parsing with short form: '-g' is the alias for '--graph-host'.
    args = parser.parse_args(['-g', 'bolt://short:7687'])
    assert args.graph_host == 'bolt://short:7687'
@patch('trustgraph.storage.triples.neo4j.write.Processor.launch')
def test_run_function(self, mock_launch):
    """Test the run function calls Processor.launch with correct parameters"""
    from trustgraph.storage.triples.neo4j.write import run, default_ident
    run()
    # run() must forward the module's default identity and its module docstring.
    mock_launch.assert_called_once_with(
        default_ident,
        "\nGraph writer. Input is graph edge. Writes edges to Neo4j graph.\n"
    )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_special_characters(self, mock_graph_db):
    """Test handling triples with special characters and unicode"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session
    # Mock execute_query response
    mock_result = MagicMock()
    mock_summary = MagicMock()
    mock_summary.counters.nodes_created = 1
    mock_summary.result_available_after = 10
    mock_result.summary = mock_summary
    mock_driver.execute_query.return_value = mock_result
    processor = Processor(taskgroup=taskgroup_mock)
    # Create triple with special characters
    triple = MagicMock()
    triple.s.value = "http://example.com/subject with spaces"
    triple.p.value = "http://example.com/predicate:with/symbols"
    triple.o.value = 'literal with "quotes" and unicode: ñáéíóú'
    triple.o.is_uri = False
    mock_message = MagicMock()
    mock_message.triples = [triple]
    await processor.store_triples(mock_message)
    # Verify the triple was processed with special characters preserved:
    # values travel as bound parameters, so no escaping should occur.
    mock_driver.execute_query.assert_any_call(
        "MERGE (n:Node {uri: $uri})",
        uri="http://example.com/subject with spaces",
        database_="neo4j"
    )
    mock_driver.execute_query.assert_any_call(
        "MERGE (n:Literal {value: $value})",
        value='literal with "quotes" and unicode: ñáéíóú',
        database_="neo4j"
    )
    mock_driver.execute_query.assert_any_call(
        "MATCH (src:Node {uri: $src}) "
        "MATCH (dest:Literal {value: $dest}) "
        "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
        src="http://example.com/subject with spaces",
        dest='literal with "quotes" and unicode: ñáéíóú',
        uri="http://example.com/predicate:with/symbols",
        database_="neo4j"
    )

View file

@ -3,6 +3,9 @@ from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
# Global list to track clusters for cleanup
_active_clusters = []
class TrustGraph:
def __init__(
@ -24,6 +27,9 @@ class TrustGraph:
else:
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
# Track this cluster globally
_active_clusters.append(self.cluster)
self.init()
@ -119,3 +125,13 @@ class TrustGraph:
f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""",
(s, p, o)
)
def close(self):
"""Close the Cassandra session and cluster connections properly"""
if hasattr(self, 'session') and self.session:
self.session.shutdown()
if hasattr(self, 'cluster') and self.cluster:
self.cluster.shutdown()
# Remove from global tracking
if self.cluster in _active_clusters:
_active_clusters.remove(self.cluster)

View file

@ -5,94 +5,56 @@ of chunks
"""
from .... direct.milvus_doc_embeddings import DocVectors
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .... schema import DocumentEmbeddingsResponse
from .... schema import Error, Value
from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
from .... base import DocumentEmbeddingsQueryService
module = "de-query"
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue
default_subscriber = module
default_ident = "de-query"
default_store_uri = 'http://localhost:19530'
class Processor(ConsumerProducer):
class Processor(DocumentEmbeddingsQueryService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": DocumentEmbeddingsRequest,
"output_schema": DocumentEmbeddingsResponse,
"store_uri": store_uri,
}
)
self.vecstore = DocVectors(store_uri)
async def handle(self, msg):
async def query_document_embeddings(self, msg):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
# Handle zero limit case
if msg.limit <= 0:
return []
chunks = []
for vec in v.vectors:
for vec in msg.vectors:
resp = self.vecstore.search(vec, limit=v.limit)
resp = self.vecstore.search(vec, limit=msg.limit)
for r in resp:
chunk = r["entity"]["doc"]
chunk = chunk.encode("utf-8")
chunks.append(chunk)
print("Send response...", flush=True)
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
return chunks
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = DocumentEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
documents=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
raise e
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
DocumentEmbeddingsQueryService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
@ -102,5 +64,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -10,30 +10,21 @@ from pinecone.grpc import PineconeGRPC, GRPCClientConfig
import uuid
import os
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .... schema import Error, Value
from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
from .... base import DocumentEmbeddingsQueryService
module = "de-query"
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue
default_subscriber = module
default_ident = "de-query"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
class Processor(ConsumerProducer):
class Processor(DocumentEmbeddingsQueryService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
self.url = params.get("url", None)
self.api_key = params.get("api_key", default_api_key)
if self.api_key is None or self.api_key == "not-specified":
raise RuntimeError("Pinecone API key must be specified")
if self.url:
self.pinecone = PineconeGRPC(
@ -47,88 +38,53 @@ class Processor(ConsumerProducer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": DocumentEmbeddingsRequest,
"output_schema": DocumentEmbeddingsResponse,
"url": self.url,
"api_key": self.api_key,
}
)
async def handle(self, msg):
async def query_document_embeddings(self, msg):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
# Handle zero limit case
if msg.limit <= 0:
return []
chunks = []
for vec in v.vectors:
for vec in msg.vectors:
dim = len(vec)
index_name = (
"d-" + v.user + "-" + str(dim)
"d-" + msg.user + "-" + msg.collection + "-" + str(dim)
)
index = self.pinecone.Index(index_name)
results = index.query(
namespace=v.collection,
vector=vec,
top_k=v.limit,
top_k=msg.limit,
include_values=False,
include_metadata=True
)
search_result = self.client.query_points(
collection_name=collection,
query=vec,
limit=v.limit,
with_payload=True,
).points
for r in results.matches:
doc = r.metadata["doc"]
chunks.add(doc)
chunks.append(doc)
print("Send response...", flush=True)
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
return chunks
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = DocumentEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
documents=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
raise e
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
DocumentEmbeddingsQueryService.add_args(parser)
parser.add_argument(
'-a', '--api-key',
@ -143,5 +99,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -5,35 +5,21 @@ entities
"""
from .... direct.milvus_graph_embeddings import EntityVectors
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
from .... base import GraphEmbeddingsQueryService
module = "ge-query"
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue
default_subscriber = module
default_ident = "ge-query"
default_store_uri = 'http://localhost:19530'
class Processor(ConsumerProducer):
class Processor(GraphEmbeddingsQueryService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddingsRequest,
"output_schema": GraphEmbeddingsResponse,
"store_uri": store_uri,
}
)
@ -46,29 +32,34 @@ class Processor(ConsumerProducer):
else:
return Value(value=ent, is_uri=False)
async def handle(self, msg):
async def query_graph_embeddings(self, msg):
try:
v = msg.value()
entity_set = set()
entities = []
# Sender-produced ID
id = msg.properties()["id"]
# Handle zero limit case
if msg.limit <= 0:
return []
print(f"Handling input {id}...", flush=True)
for vec in msg.vectors:
entities = set()
for vec in v.vectors:
resp = self.vecstore.search(vec, limit=v.limit)
resp = self.vecstore.search(vec, limit=msg.limit * 2)
for r in resp:
ent = r["entity"]["entity"]
entities.add(ent)
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Convert set to list
entities = list(entities)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
@ -78,36 +69,19 @@ class Processor(ConsumerProducer):
entities = ents2
print("Send response...", flush=True)
r = GraphEmbeddingsResponse(entities=entities, error=None)
await self.send(r, properties={"id": id})
return entities
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = GraphEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
entities=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
raise e
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
GraphEmbeddingsQueryService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
@ -117,5 +91,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -10,30 +10,23 @@ from pinecone.grpc import PineconeGRPC, GRPCClientConfig
import uuid
import os
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
from .... base import GraphEmbeddingsQueryService
module = "ge-query"
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue
default_subscriber = module
default_ident = "ge-query"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
class Processor(ConsumerProducer):
class Processor(GraphEmbeddingsQueryService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
self.url = params.get("url", None)
self.api_key = params.get("api_key", default_api_key)
if self.api_key is None or self.api_key == "not-specified":
raise RuntimeError("Pinecone API key must be specified")
if self.url:
self.pinecone = PineconeGRPC(
@ -47,12 +40,8 @@ class Processor(ConsumerProducer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddingsRequest,
"output_schema": GraphEmbeddingsResponse,
"url": self.url,
"api_key": self.api_key,
}
)
@ -62,26 +51,23 @@ class Processor(ConsumerProducer):
else:
return Value(value=ent, is_uri=False)
async def handle(self, msg):
async def query_graph_embeddings(self, msg):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
# Handle zero limit case
if msg.limit <= 0:
return []
entity_set = set()
entities = []
for vec in v.vectors:
for vec in msg.vectors:
dim = len(vec)
index_name = (
"t-" + v.user + "-" + str(dim)
"t-" + msg.user + "-" + msg.collection + "-" + str(dim)
)
index = self.pinecone.Index(index_name)
@ -89,9 +75,8 @@ class Processor(ConsumerProducer):
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
results = index.query(
namespace=v.collection,
vector=vec,
top_k=v.limit * 2,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
@ -106,10 +91,10 @@ class Processor(ConsumerProducer):
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= v.limit: break
if len(entity_set) >= msg.limit: break
# Keep adding entities until limit
if len(entity_set) >= v.limit: break
if len(entity_set) >= msg.limit: break
ents2 = []
@ -118,37 +103,17 @@ class Processor(ConsumerProducer):
entities = ents2
print("Send response...", flush=True)
r = GraphEmbeddingsResponse(entities=entities, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
return entities
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = GraphEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
entities=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
raise e
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
GraphEmbeddingsQueryService.add_args(parser)
parser.add_argument(
'-a', '--api-key',
@ -163,5 +128,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -9,37 +9,24 @@ from falkordb import FalkorDB
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
from .... base import TriplesQueryService
module = "triples-query"
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue
default_subscriber = module
default_ident = "triples-query"
default_graph_url = 'falkor://falkordb:6379'
default_database = 'falkordb'
class Processor(ConsumerProducer):
class Processor(TriplesQueryService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
graph_url = params.get("graph_host", default_graph_url)
graph_url = params.get("graph_url", default_graph_url)
database = params.get("database", default_database)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TriplesQueryRequest,
"output_schema": TriplesQueryResponse,
"graph_url": graph_url,
"database": database,
}
)
@ -54,50 +41,45 @@ class Processor(ConsumerProducer):
else:
return Value(value=ent, is_uri=False)
async def handle(self, msg):
async def query_triples(self, query):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
triples = []
if v.s is not None:
if v.p is not None:
if v.o is not None:
if query.s is not None:
if query.p is not None:
if query.o is not None:
# SPO
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
"RETURN $src as src",
"RETURN $src as src "
"LIMIT " + str(query.limit),
params={
"src": v.s.value,
"rel": v.p.value,
"value": v.o.value,
"src": query.s.value,
"rel": query.p.value,
"value": query.o.value,
},
).result_set
for rec in records:
triples.append((v.s.value, v.p.value, v.o.value))
triples.append((query.s.value, query.p.value, query.o.value))
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
"RETURN $src as src",
"RETURN $src as src "
"LIMIT " + str(query.limit),
params={
"src": v.s.value,
"rel": v.p.value,
"uri": v.o.value,
"src": query.s.value,
"rel": query.p.value,
"uri": query.o.value,
},
).result_set
for rec in records:
triples.append((v.s.value, v.p.value, v.o.value))
triples.append((query.s.value, query.p.value, query.o.value))
else:
@ -105,116 +87,124 @@ class Processor(ConsumerProducer):
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
"RETURN dest.value as dest",
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
params={
"src": v.s.value,
"rel": v.p.value,
"src": query.s.value,
"rel": query.p.value,
},
).result_set
for rec in records:
triples.append((v.s.value, v.p.value, rec[0]))
triples.append((query.s.value, query.p.value, rec[0]))
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
"RETURN dest.uri as dest",
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
params={
"src": v.s.value,
"rel": v.p.value,
"src": query.s.value,
"rel": query.p.value,
},
).result_set
for rec in records:
triples.append((v.s.value, v.p.value, rec[0]))
triples.append((query.s.value, query.p.value, rec[0]))
else:
if v.o is not None:
if query.o is not None:
# SO
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
"RETURN rel.uri as rel",
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
params={
"src": v.s.value,
"value": v.o.value,
"src": query.s.value,
"value": query.o.value,
},
).result_set
for rec in records:
triples.append((v.s.value, rec[0], v.o.value))
triples.append((query.s.value, rec[0], query.o.value))
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
"RETURN rel.uri as rel",
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
params={
"src": v.s.value,
"uri": v.o.value,
"src": query.s.value,
"uri": query.o.value,
},
).result_set
for rec in records:
triples.append((v.s.value, rec[0], v.o.value))
triples.append((query.s.value, rec[0], query.o.value))
else:
# s
records = self.io.query(
"match (src:node {uri: $src})-[rel:rel]->(dest:literal) "
"return rel.uri as rel, dest.value as dest",
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
params={
"src": v.s.value,
"src": query.s.value,
},
).result_set
for rec in records:
triples.append((v.s.value, rec[0], rec[1]))
triples.append((query.s.value, rec[0], rec[1]))
records = self.io.query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
"RETURN rel.uri as rel, dest.uri as dest",
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
params={
"src": v.s.value,
"src": query.s.value,
},
).result_set
for rec in records:
triples.append((v.s.value, rec[0], rec[1]))
triples.append((query.s.value, rec[0], rec[1]))
else:
if v.p is not None:
if query.p is not None:
if v.o is not None:
if query.o is not None:
# PO
records = self.io.query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
"RETURN src.uri as src",
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
params={
"uri": v.p.value,
"value": v.o.value,
"uri": query.p.value,
"value": query.o.value,
},
).result_set
for rec in records:
triples.append((rec[0], v.p.value, v.o.value))
triples.append((rec[0], query.p.value, query.o.value))
records = self.io.query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) "
"RETURN src.uri as src",
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
params={
"uri": v.p.value,
"dest": v.o.value,
"uri": query.p.value,
"dest": query.o.value,
},
).result_set
for rec in records:
triples.append((rec[0], v.p.value, v.o.value))
triples.append((rec[0], query.p.value, query.o.value))
else:
@ -222,53 +212,57 @@ class Processor(ConsumerProducer):
records = self.io.query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
"RETURN src.uri as src, dest.value as dest",
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
params={
"uri": v.p.value,
"uri": query.p.value,
},
).result_set
for rec in records:
triples.append((rec[0], v.p.value, rec[1]))
triples.append((rec[0], query.p.value, rec[1]))
records = self.io.query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
"RETURN src.uri as src, dest.uri as dest",
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
params={
"uri": v.p.value,
"uri": query.p.value,
},
).result_set
for rec in records:
triples.append((rec[0], v.p.value, rec[1]))
triples.append((rec[0], query.p.value, rec[1]))
else:
if v.o is not None:
if query.o is not None:
# O
records = self.io.query(
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
"RETURN src.uri as src, rel.uri as rel",
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
params={
"value": v.o.value,
"value": query.o.value,
},
).result_set
for rec in records:
triples.append((rec[0], rec[1], v.o.value))
triples.append((rec[0], rec[1], query.o.value))
records = self.io.query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
"RETURN src.uri as src, rel.uri as rel",
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
params={
"uri": v.o.value,
"uri": query.o.value,
},
).result_set
for rec in records:
triples.append((rec[0], rec[1], v.o.value))
triples.append((rec[0], rec[1], query.o.value))
else:
@ -276,7 +270,8 @@ class Processor(ConsumerProducer):
records = self.io.query(
"MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest",
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
).result_set
for rec in records:
@ -284,7 +279,8 @@ class Processor(ConsumerProducer):
records = self.io.query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest",
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
).result_set
for rec in records:
@ -296,40 +292,20 @@ class Processor(ConsumerProducer):
p=self.create_value(t[1]),
o=self.create_value(t[2])
)
for t in triples
for t in triples[:query.limit]
]
print("Send response...", flush=True)
r = TriplesQueryResponse(triples=triples, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
return triples
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
raise e
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
TriplesQueryService.add_args(parser)
parser.add_argument(
'-g', '--graph-url',
@ -345,5 +321,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -9,28 +9,19 @@ from neo4j import GraphDatabase
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
from .... base import TriplesQueryService
module = "triples-query"
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue
default_subscriber = module
default_ident = "triples-query"
default_graph_host = 'bolt://memgraph:7687'
default_username = 'memgraph'
default_password = 'password'
default_database = 'memgraph'
class Processor(ConsumerProducer):
class Processor(TriplesQueryService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
graph_host = params.get("graph_host", default_graph_host)
username = params.get("username", default_username)
password = params.get("password", default_password)
@ -38,12 +29,9 @@ class Processor(ConsumerProducer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TriplesQueryRequest,
"output_schema": TriplesQueryResponse,
"graph_host": graph_host,
"username": username,
"database": database,
}
)
@ -58,46 +46,39 @@ class Processor(ConsumerProducer):
else:
return Value(value=ent, is_uri=False)
async def handle(self, msg):
async def query_triples(self, query):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
triples = []
if v.s is not None:
if v.p is not None:
if v.o is not None:
if query.s is not None:
if query.p is not None:
if query.o is not None:
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
"RETURN $src as src "
"LIMIT " + str(v.limit),
src=v.s.value, rel=v.p.value, value=v.o.value,
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value, value=query.o.value,
database_=self.db,
)
for rec in records:
triples.append((v.s.value, v.p.value, v.o.value))
triples.append((query.s.value, query.p.value, query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
"RETURN $src as src "
"LIMIT " + str(v.limit),
src=v.s.value, rel=v.p.value, uri=v.o.value,
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value, uri=query.o.value,
database_=self.db,
)
for rec in records:
triples.append((v.s.value, v.p.value, v.o.value))
triples.append((query.s.value, query.p.value, query.o.value))
else:
@ -106,56 +87,56 @@ class Processor(ConsumerProducer):
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
"RETURN dest.value as dest "
"LIMIT " + str(v.limit),
src=v.s.value, rel=v.p.value,
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, v.p.value, data["dest"]))
triples.append((query.s.value, query.p.value, data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
"RETURN dest.uri as dest "
"LIMIT " + str(v.limit),
src=v.s.value, rel=v.p.value,
"LIMIT " + str(query.limit),
src=query.s.value, rel=query.p.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, v.p.value, data["dest"]))
triples.append((query.s.value, query.p.value, data["dest"]))
else:
if v.o is not None:
if query.o is not None:
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
"RETURN rel.uri as rel "
"LIMIT " + str(v.limit),
src=v.s.value, value=v.o.value,
"LIMIT " + str(query.limit),
src=query.s.value, value=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, data["rel"], v.o.value))
triples.append((query.s.value, data["rel"], query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
"RETURN rel.uri as rel "
"LIMIT " + str(v.limit),
src=v.s.value, uri=v.o.value,
"LIMIT " + str(query.limit),
src=query.s.value, uri=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, data["rel"], v.o.value))
triples.append((query.s.value, data["rel"], query.o.value))
else:
@ -164,59 +145,59 @@ class Processor(ConsumerProducer):
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(v.limit),
src=v.s.value,
"LIMIT " + str(query.limit),
src=query.s.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, data["rel"], data["dest"]))
triples.append((query.s.value, data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(v.limit),
src=v.s.value,
"LIMIT " + str(query.limit),
src=query.s.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, data["rel"], data["dest"]))
triples.append((query.s.value, data["rel"], data["dest"]))
else:
if v.p is not None:
if query.p is not None:
if v.o is not None:
if query.o is not None:
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
"RETURN src.uri as src "
"LIMIT " + str(v.limit),
uri=v.p.value, value=v.o.value,
"LIMIT " + str(query.limit),
uri=query.p.value, value=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], v.p.value, v.o.value))
triples.append((data["src"], query.p.value, query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) "
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
"RETURN src.uri as src "
"LIMIT " + str(v.limit),
uri=v.p.value, dest=v.o.value,
"LIMIT " + str(query.limit),
uri=query.p.value, dest=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], v.p.value, v.o.value))
triples.append((data["src"], query.p.value, query.o.value))
else:
@ -225,56 +206,56 @@ class Processor(ConsumerProducer):
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(v.limit),
uri=v.p.value,
"LIMIT " + str(query.limit),
uri=query.p.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], v.p.value, data["dest"]))
triples.append((data["src"], query.p.value, data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(v.limit),
uri=v.p.value,
"LIMIT " + str(query.limit),
uri=query.p.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], v.p.value, data["dest"]))
triples.append((data["src"], query.p.value, data["dest"]))
else:
if v.o is not None:
if query.o is not None:
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(v.limit),
value=v.o.value,
"LIMIT " + str(query.limit),
value=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], data["rel"], v.o.value))
triples.append((data["src"], data["rel"], query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(v.limit),
uri=v.o.value,
"LIMIT " + str(query.limit),
uri=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], data["rel"], v.o.value))
triples.append((data["src"], data["rel"], query.o.value))
else:
@ -283,7 +264,7 @@ class Processor(ConsumerProducer):
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(v.limit),
"LIMIT " + str(query.limit),
database_=self.db,
)
@ -294,7 +275,7 @@ class Processor(ConsumerProducer):
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(v.limit),
"LIMIT " + str(query.limit),
database_=self.db,
)
@ -308,40 +289,22 @@ class Processor(ConsumerProducer):
p=self.create_value(t[1]),
o=self.create_value(t[2])
)
for t in triples[:v.limit]
for t in triples[:query.limit]
]
print("Send response...", flush=True)
r = TriplesQueryResponse(triples=triples, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
return triples
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
print(f"Exception: {e}")
raise e
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
TriplesQueryService.add_args(parser)
parser.add_argument(
'-g', '--graph-host',
@ -369,5 +332,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -9,28 +9,19 @@ from neo4j import GraphDatabase
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... schema import triples_request_queue
from .... schema import triples_response_queue
from .... base import ConsumerProducer
from .... base import TriplesQueryService
module = "triples-query"
default_input_queue = triples_request_queue
default_output_queue = triples_response_queue
default_subscriber = module
default_ident = "triples-query"
default_graph_host = 'bolt://neo4j:7687'
default_username = 'neo4j'
default_password = 'password'
default_database = 'neo4j'
class Processor(ConsumerProducer):
class Processor(TriplesQueryService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
graph_host = params.get("graph_host", default_graph_host)
username = params.get("username", default_username)
password = params.get("password", default_password)
@ -38,12 +29,9 @@ class Processor(ConsumerProducer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TriplesQueryRequest,
"output_schema": TriplesQueryResponse,
"graph_host": graph_host,
"username": username,
"database": database,
}
)
@ -58,44 +46,37 @@ class Processor(ConsumerProducer):
else:
return Value(value=ent, is_uri=False)
async def handle(self, msg):
async def query_triples(self, query):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
triples = []
if v.s is not None:
if v.p is not None:
if v.o is not None:
if query.s is not None:
if query.p is not None:
if query.o is not None:
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
"RETURN $src as src",
src=v.s.value, rel=v.p.value, value=v.o.value,
src=query.s.value, rel=query.p.value, value=query.o.value,
database_=self.db,
)
for rec in records:
triples.append((v.s.value, v.p.value, v.o.value))
triples.append((query.s.value, query.p.value, query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
"RETURN $src as src",
src=v.s.value, rel=v.p.value, uri=v.o.value,
src=query.s.value, rel=query.p.value, uri=query.o.value,
database_=self.db,
)
for rec in records:
triples.append((v.s.value, v.p.value, v.o.value))
triples.append((query.s.value, query.p.value, query.o.value))
else:
@ -104,52 +85,52 @@ class Processor(ConsumerProducer):
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
"RETURN dest.value as dest",
src=v.s.value, rel=v.p.value,
src=query.s.value, rel=query.p.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, v.p.value, data["dest"]))
triples.append((query.s.value, query.p.value, data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
"RETURN dest.uri as dest",
src=v.s.value, rel=v.p.value,
src=query.s.value, rel=query.p.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, v.p.value, data["dest"]))
triples.append((query.s.value, query.p.value, data["dest"]))
else:
if v.o is not None:
if query.o is not None:
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
"RETURN rel.uri as rel",
src=v.s.value, value=v.o.value,
src=query.s.value, value=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, data["rel"], v.o.value))
triples.append((query.s.value, data["rel"], query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
"RETURN rel.uri as rel",
src=v.s.value, uri=v.o.value,
src=query.s.value, uri=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, data["rel"], v.o.value))
triples.append((query.s.value, data["rel"], query.o.value))
else:
@ -158,55 +139,55 @@ class Processor(ConsumerProducer):
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
"RETURN rel.uri as rel, dest.value as dest",
src=v.s.value,
src=query.s.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, data["rel"], data["dest"]))
triples.append((query.s.value, data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
"RETURN rel.uri as rel, dest.uri as dest",
src=v.s.value,
src=query.s.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((v.s.value, data["rel"], data["dest"]))
triples.append((query.s.value, data["rel"], data["dest"]))
else:
if v.p is not None:
if query.p is not None:
if v.o is not None:
if query.o is not None:
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
"RETURN src.uri as src",
uri=v.p.value, value=v.o.value,
uri=query.p.value, value=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], v.p.value, v.o.value))
triples.append((data["src"], query.p.value, query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) "
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
"RETURN src.uri as src",
uri=v.p.value, dest=v.o.value,
uri=query.p.value, dest=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], v.p.value, v.o.value))
triples.append((data["src"], query.p.value, query.o.value))
else:
@ -215,52 +196,52 @@ class Processor(ConsumerProducer):
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
"RETURN src.uri as src, dest.value as dest",
uri=v.p.value,
uri=query.p.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], v.p.value, data["dest"]))
triples.append((data["src"], query.p.value, data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
"RETURN src.uri as src, dest.uri as dest",
uri=v.p.value,
uri=query.p.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], v.p.value, data["dest"]))
triples.append((data["src"], query.p.value, data["dest"]))
else:
if v.o is not None:
if query.o is not None:
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
"RETURN src.uri as src, rel.uri as rel",
value=v.o.value,
value=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], data["rel"], v.o.value))
triples.append((data["src"], data["rel"], query.o.value))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
"RETURN src.uri as src, rel.uri as rel",
uri=v.o.value,
uri=query.o.value,
database_=self.db,
)
for rec in records:
data = rec.data()
triples.append((data["src"], data["rel"], v.o.value))
triples.append((data["src"], data["rel"], query.o.value))
else:
@ -295,37 +276,17 @@ class Processor(ConsumerProducer):
for t in triples
]
print("Send response...", flush=True)
r = TriplesQueryResponse(triples=triples, error=None)
await self.send(r, properties={"id": id})
print("Done.", flush=True)
return triples
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TriplesQueryResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
await self.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
raise e
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
TriplesQueryService.add_args(parser)
parser.add_argument(
'-g', '--graph-host',
@ -353,5 +314,5 @@ class Processor(ConsumerProducer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -4,58 +4,41 @@ Accepts entity/vector pairs and writes them to a Milvus store.
"""
from .... direct.milvus_doc_embeddings import DocVectors
from .... base import DocumentEmbeddingsStoreService
from .... schema import DocumentEmbeddings
from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = "de-write"
default_input_queue = document_embeddings_store_queue
default_subscriber = module
default_ident = "de-write"
default_store_uri = 'http://localhost:19530'
class Processor(Consumer):
class Processor(DocumentEmbeddingsStoreService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": DocumentEmbeddings,
"store_uri": store_uri,
}
)
self.vecstore = DocVectors(store_uri)
async def handle(self, msg):
async def store_document_embeddings(self, message):
v = msg.value()
for emb in v.chunks:
for emb in message.chunks:
if emb.chunk is None or emb.chunk == b"": continue
chunk = emb.chunk.decode("utf-8")
if chunk == "" or chunk is None: continue
if chunk == "": continue
for vec in emb.vectors:
if chunk != "" and v.chunk is not None:
for vec in v.vectors:
self.vecstore.insert(vec, chunk)
self.vecstore.insert(vec, chunk)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
DocumentEmbeddingsStoreService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
@ -65,5 +48,5 @@ class Processor(Consumer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -1,42 +1,32 @@
"""
Accepts entity/vector pairs and writes them to a Qdrant store.
Accepts document chunks/vector pairs and writes them to a Pinecone store.
"""
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
import time
import uuid
import os
from .... schema import DocumentEmbeddings
from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
from .... base import DocumentEmbeddingsStoreService
module = "de-write"
default_input_queue = document_embeddings_store_queue
default_subscriber = module
default_ident = "de-write"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"
class Processor(Consumer):
class Processor(DocumentEmbeddingsStoreService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
self.url = params.get("url", None)
self.cloud = params.get("cloud", default_cloud)
self.region = params.get("region", default_region)
self.api_key = params.get("api_key", default_api_key)
if self.api_key is None:
if self.api_key is None or self.api_key == "not-specified":
raise RuntimeError("Pinecone API key must be specified")
if self.url:
@ -52,94 +42,96 @@ class Processor(Consumer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": DocumentEmbeddings,
"url": self.url,
"cloud": self.cloud,
"region": self.region,
"api_key": self.api_key,
}
)
self.last_index_name = None
async def handle(self, msg):
def create_index(self, index_name, dim):
v = msg.value()
self.pinecone.create_index(
name = index_name,
dimension = dim,
metric = "cosine",
spec = ServerlessSpec(
cloud = self.cloud,
region = self.region,
)
)
for emb in v.chunks:
for i in range(0, 1000):
if self.pinecone.describe_index(
index_name
).status["ready"]:
break
time.sleep(1)
if not self.pinecone.describe_index(
index_name
).status["ready"]:
raise RuntimeError(
"Gave up waiting for index creation"
)
async def store_document_embeddings(self, message):
for emb in message.chunks:
if emb.chunk is None or emb.chunk == b"": continue
chunk = emb.chunk.decode("utf-8")
if chunk == "" or chunk is None: continue
if chunk == "": continue
for vec in emb.vectors:
for vec in v.vectors:
dim = len(vec)
index_name = (
"d-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
)
dim = len(vec)
collection = (
"d-" + v.metadata.user + "-" + str(dim)
)
if index_name != self.last_index_name:
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
if not self.pinecone.has_index(index_name):
try:
try:
self.create_index(index_name, dim)
self.pinecone.create_index(
name = index_name,
dimension = dim,
metric = "cosine",
spec = ServerlessSpec(
cloud = self.cloud,
region = self.region,
)
)
except Exception as e:
print("Pinecone index creation failed")
raise e
for i in range(0, 1000):
print(f"Index {index_name} created", flush=True)
if self.pinecone.describe_index(
index_name
).status["ready"]:
break
self.last_index_name = index_name
time.sleep(1)
index = self.pinecone.Index(index_name)
if not self.pinecone.describe_index(
index_name
).status["ready"]:
raise RuntimeError(
"Gave up waiting for index creation"
)
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
except Exception as e:
print("Pinecone index creation failed")
raise e
records = [
{
"id": vector_id,
"values": vec,
"metadata": { "doc": chunk },
}
]
print(f"Index {index_name} created", flush=True)
self.last_index_name = index_name
index = self.pinecone.Index(index_name)
records = [
{
"id": id,
"values": vec,
"metadata": { "doc": chunk },
}
]
index.upsert(
vectors = records,
namespace = v.metadata.collection,
)
index.upsert(
vectors = records,
)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
DocumentEmbeddingsStoreService.add_args(parser)
parser.add_argument(
'-a', '--api-key',
@ -166,5 +158,5 @@ class Processor(Consumer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -3,42 +3,29 @@
Accepts entity/vector pairs and writes them to a Milvus store.
"""
from .... schema import GraphEmbeddings
from .... schema import graph_embeddings_store_queue
from .... log_level import LogLevel
from .... direct.milvus_graph_embeddings import EntityVectors
from .... base import Consumer
from .... base import GraphEmbeddingsStoreService
module = "ge-write"
default_input_queue = graph_embeddings_store_queue
default_subscriber = module
default_ident = "ge-write"
default_store_uri = 'http://localhost:19530'
class Processor(Consumer):
class Processor(GraphEmbeddingsStoreService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
store_uri = params.get("store_uri", default_store_uri)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddings,
"store_uri": store_uri,
}
)
self.vecstore = EntityVectors(store_uri)
async def handle(self, msg):
async def store_graph_embeddings(self, message):
v = msg.value()
for entity in v.entities:
for entity in message.entities:
if entity.entity.value != "" and entity.entity.value is not None:
for vec in entity.vectors:
@ -47,9 +34,7 @@ class Processor(Consumer):
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
GraphEmbeddingsStoreService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
@ -59,5 +44,5 @@ class Processor(Consumer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -10,32 +10,23 @@ import time
import uuid
import os
from .... schema import GraphEmbeddings
from .... schema import graph_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
from .... base import GraphEmbeddingsStoreService
module = "ge-write"
default_input_queue = graph_embeddings_store_queue
default_subscriber = module
default_ident = "ge-write"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"
class Processor(Consumer):
class Processor(GraphEmbeddingsStoreService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
self.url = params.get("url", None)
self.cloud = params.get("cloud", default_cloud)
self.region = params.get("region", default_region)
self.api_key = params.get("api_key", default_api_key)
if self.api_key is None:
if self.api_key is None or self.api_key == "not-specified":
raise RuntimeError("Pinecone API key must be specified")
if self.url:
@ -51,10 +42,10 @@ class Processor(Consumer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddings,
"url": self.url,
"cloud": self.cloud,
"region": self.region,
"api_key": self.api_key,
}
)
@ -88,13 +79,9 @@ class Processor(Consumer):
"Gave up waiting for index creation"
)
async def handle(self, msg):
async def store_graph_embeddings(self, message):
v = msg.value()
id = str(uuid.uuid4())
for entity in v.entities:
for entity in message.entities:
if entity.entity.value == "" or entity.entity.value is None:
continue
@ -104,7 +91,7 @@ class Processor(Consumer):
dim = len(vec)
index_name = (
"t-" + v.metadata.user + "-" + str(dim)
"t-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
)
if index_name != self.last_index_name:
@ -125,9 +112,12 @@ class Processor(Consumer):
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
records = [
{
"id": id,
"id": vector_id,
"values": vec,
"metadata": { "entity": entity.entity.value },
}
@ -135,15 +125,12 @@ class Processor(Consumer):
index.upsert(
vectors = records,
namespace = v.metadata.collection,
)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
GraphEmbeddingsStoreService.add_args(parser)
parser.add_argument(
'-a', '--api-key',
@ -170,5 +157,5 @@ class Processor(Consumer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -11,34 +11,24 @@ import time
from falkordb import FalkorDB
from .... schema import Triples
from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
from .... base import TriplesStoreService
module = "triples-write"
default_input_queue = triples_store_queue
default_subscriber = module
default_ident = "triples-write"
default_graph_url = 'falkor://falkordb:6379'
default_database = 'falkordb'
class Processor(Consumer):
class Processor(TriplesStoreService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
graph_url = params.get("graph_host", default_graph_url)
graph_url = params.get("graph_url", default_graph_url)
database = params.get("database", default_database)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Triples,
"graph_url": graph_url,
"database": database,
}
)
@ -118,11 +108,9 @@ class Processor(Consumer):
time=res.run_time_ms
))
async def handle(self, msg):
async def store_triples(self, message):
v = msg.value()
for t in v.triples:
for t in message.triples:
self.create_node(t.s.value)
@ -136,14 +124,12 @@ class Processor(Consumer):
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
TriplesStoreService.add_args(parser)
parser.add_argument(
'-g', '--graph_host',
'-g', '--graph-url',
default=default_graph_url,
help=f'Graph host (default: {default_graph_url})'
help=f'Graph URL (default: {default_graph_url})'
)
parser.add_argument(
@ -154,5 +140,5 @@ class Processor(Consumer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -11,27 +11,19 @@ import time
from neo4j import GraphDatabase
from .... schema import Triples
from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
from .... base import TriplesStoreService
module = "triples-write"
default_input_queue = triples_store_queue
default_subscriber = module
default_ident = "triples-write"
default_graph_host = 'bolt://memgraph:7687'
default_username = 'memgraph'
default_password = 'password'
default_database = 'memgraph'
class Processor(Consumer):
class Processor(TriplesStoreService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
graph_host = params.get("graph_host", default_graph_host)
username = params.get("username", default_username)
password = params.get("password", default_password)
@ -39,10 +31,10 @@ class Processor(Consumer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Triples,
"graph_host": graph_host,
"username": username,
"password": password,
"database": database,
}
)
@ -205,11 +197,9 @@ class Processor(Consumer):
src=t.s.value, dest=t.o.value, uri=t.p.value,
)
async def handle(self, msg):
async def store_triples(self, message):
v = msg.value()
for t in v.triples:
for t in message.triples:
# self.create_node(t.s.value)
@ -226,12 +216,10 @@ class Processor(Consumer):
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
TriplesStoreService.add_args(parser)
parser.add_argument(
'-g', '--graph_host',
'-g', '--graph-host',
default=default_graph_host,
help=f'Graph host (default: {default_graph_host})'
)
@ -256,5 +244,5 @@ class Processor(Consumer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)

View file

@ -10,28 +10,21 @@ import argparse
import time
from neo4j import GraphDatabase
from .... base import TriplesStoreService
from .... schema import Triples
from .... schema import triples_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = "triples-write"
default_input_queue = triples_store_queue
default_subscriber = module
default_ident = "triples-write"
default_graph_host = 'bolt://neo4j:7687'
default_username = 'neo4j'
default_password = 'password'
default_database = 'neo4j'
class Processor(Consumer):
class Processor(TriplesStoreService):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
id = params.get("id", default_ident)
graph_host = params.get("graph_host", default_graph_host)
username = params.get("username", default_username)
password = params.get("password", default_password)
@ -39,10 +32,9 @@ class Processor(Consumer):
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Triples,
"graph_host": graph_host,
"username": username,
"database": database,
}
)
@ -158,11 +150,9 @@ class Processor(Consumer):
time=summary.result_available_after
))
async def handle(self, msg):
async def store_triples(self, message):
v = msg.value()
for t in v.triples:
for t in message.triples:
self.create_node(t.s.value)
@ -176,9 +166,7 @@ class Processor(Consumer):
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
TriplesStoreService.add_args(parser)
parser.add_argument(
'-g', '--graph_host',
@ -206,5 +194,5 @@ class Processor(Consumer):
def run():
Processor.launch(module, __doc__)
Processor.launch(default_ident, __doc__)