mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Increase storage test coverage (#435)
* Fixing storage and adding tests * PR pipeline only runs quick tests
This commit is contained in:
parent
4daa54abaf
commit
f37decea2b
33 changed files with 7606 additions and 754 deletions
4
.github/workflows/pull-request.yaml
vendored
4
.github/workflows/pull-request.yaml
vendored
|
|
@ -48,8 +48,8 @@ jobs:
|
|||
- name: Unit tests
|
||||
run: pytest tests/unit
|
||||
|
||||
- name: Integration tests
|
||||
run: pytest tests/integration
|
||||
- name: Integration tests (cut the out the long-running tests)
|
||||
run: pytest tests/integration -m 'not slow'
|
||||
|
||||
- name: Contract tests
|
||||
run: pytest tests/contract
|
||||
|
|
|
|||
112
tests/integration/cassandra_test_helper.py
Normal file
112
tests/integration/cassandra_test_helper.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
"""
|
||||
Helper for managing Cassandra containers in integration tests
|
||||
Alternative to testcontainers for Fedora/Podman compatibility
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
import socket
|
||||
from contextlib import contextmanager
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.policies import RetryPolicy
|
||||
|
||||
|
||||
class CassandraTestContainer:
|
||||
"""Simple Cassandra container manager using Podman"""
|
||||
|
||||
def __init__(self, image="docker.io/library/cassandra:4.1", port=9042):
|
||||
self.image = image
|
||||
self.port = port
|
||||
self.container_name = f"test-cassandra-{int(time.time())}"
|
||||
self.container_id = None
|
||||
|
||||
def start(self):
|
||||
"""Start Cassandra container"""
|
||||
# Remove any existing container with same name
|
||||
subprocess.run([
|
||||
"podman", "rm", "-f", self.container_name
|
||||
], capture_output=True)
|
||||
|
||||
# Start new container with faster startup options
|
||||
result = subprocess.run([
|
||||
"podman", "run", "-d",
|
||||
"--name", self.container_name,
|
||||
"-p", f"{self.port}:9042",
|
||||
"-e", "JVM_OPTS=-Dcassandra.skip_wait_for_gossip_to_settle=0",
|
||||
self.image
|
||||
], capture_output=True, text=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to start container: {result.stderr}")
|
||||
|
||||
self.container_id = result.stdout.strip()
|
||||
|
||||
# Wait for Cassandra to be ready
|
||||
self._wait_for_ready()
|
||||
return self
|
||||
|
||||
def stop(self):
|
||||
"""Stop and remove container"""
|
||||
import time
|
||||
if self.container_name:
|
||||
# Small delay before stopping to ensure connections are closed
|
||||
time.sleep(0.5)
|
||||
subprocess.run([
|
||||
"podman", "rm", "-f", self.container_name
|
||||
], capture_output=True)
|
||||
|
||||
def get_connection_host_port(self):
|
||||
"""Get host and port for connection"""
|
||||
return "localhost", self.port
|
||||
|
||||
def _wait_for_ready(self, timeout=120):
|
||||
"""Wait for Cassandra to be ready for CQL queries"""
|
||||
start_time = time.time()
|
||||
|
||||
print(f"Waiting for Cassandra to be ready on port {self.port}...")
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
# First check if port is open
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(1)
|
||||
result = sock.connect_ex(("localhost", self.port))
|
||||
sock.close()
|
||||
|
||||
if result == 0:
|
||||
# Port is open, now try to connect with Cassandra driver
|
||||
try:
|
||||
cluster = Cluster(['localhost'], port=self.port)
|
||||
cluster.connect_timeout = 5
|
||||
session = cluster.connect()
|
||||
|
||||
# Try a simple query to verify Cassandra is ready
|
||||
session.execute("SELECT release_version FROM system.local")
|
||||
session.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
print("Cassandra is ready!")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
print(f"Cassandra not ready yet: {e}")
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
print(f"Connection check failed: {e}")
|
||||
pass
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
raise RuntimeError(f"Cassandra not ready after {timeout} seconds")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cassandra_container(image="docker.io/library/cassandra:4.1", port=9042):
|
||||
"""Context manager for Cassandra container"""
|
||||
container = CassandraTestContainer(image, port)
|
||||
try:
|
||||
container.start()
|
||||
yield container
|
||||
finally:
|
||||
container.stop()
|
||||
|
|
@ -383,4 +383,22 @@ def sample_kg_triples():
|
|||
|
||||
|
||||
# Test markers for integration tests
|
||||
pytestmark = pytest.mark.integration
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def pytest_sessionfinish(session, exitstatus):
|
||||
"""
|
||||
Called after whole test run finished, right before returning the exit status.
|
||||
|
||||
This hook is used to ensure Cassandra driver threads have time to shut down
|
||||
properly before pytest exits, preventing "cannot schedule new futures after
|
||||
shutdown" errors.
|
||||
"""
|
||||
import time
|
||||
import gc
|
||||
|
||||
# Force garbage collection to clean up any remaining objects
|
||||
gc.collect()
|
||||
|
||||
# Give Cassandra driver threads more time to clean up
|
||||
time.sleep(2)
|
||||
411
tests/integration/test_cassandra_integration.py
Normal file
411
tests/integration/test_cassandra_integration.py
Normal file
|
|
@ -0,0 +1,411 @@
|
|||
"""
|
||||
Cassandra integration tests using Podman containers
|
||||
|
||||
These tests verify end-to-end functionality of Cassandra storage and query processors
|
||||
with real database instances. Compatible with Fedora Linux and Podman.
|
||||
|
||||
Uses a single container for all tests to minimize startup time.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from .cassandra_test_helper import cassandra_container
|
||||
from trustgraph.direct.cassandra import TrustGraph
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor
|
||||
from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor
|
||||
from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.slow
|
||||
class TestCassandraIntegration:
|
||||
"""Integration tests for Cassandra using a single shared container"""
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def cassandra_shared_container(self):
|
||||
"""Class-level fixture: single Cassandra container for all tests"""
|
||||
with cassandra_container() as container:
|
||||
yield container
|
||||
|
||||
def setup_method(self):
|
||||
"""Track all created clients for cleanup"""
|
||||
self.clients_to_close = []
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up all Cassandra connections"""
|
||||
import gc
|
||||
|
||||
for client in self.clients_to_close:
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
# Clear the list and force garbage collection
|
||||
self.clients_to_close.clear()
|
||||
gc.collect()
|
||||
|
||||
# Small delay to let threads finish
|
||||
time.sleep(0.5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_cassandra_integration(self, cassandra_shared_container):
|
||||
"""Complete integration test covering all Cassandra functionality"""
|
||||
container = cassandra_shared_container
|
||||
host, port = container.get_connection_host_port()
|
||||
|
||||
print("=" * 60)
|
||||
print("RUNNING COMPLETE CASSANDRA INTEGRATION TEST")
|
||||
print("=" * 60)
|
||||
|
||||
# =====================================================
|
||||
# Test 1: Basic TrustGraph Operations
|
||||
# =====================================================
|
||||
print("\n1. Testing basic TrustGraph operations...")
|
||||
|
||||
client = TrustGraph(
|
||||
hosts=[host],
|
||||
keyspace="test_basic",
|
||||
table="test_table"
|
||||
)
|
||||
self.clients_to_close.append(client)
|
||||
|
||||
# Insert test data
|
||||
client.insert("http://example.org/alice", "knows", "http://example.org/bob")
|
||||
client.insert("http://example.org/alice", "age", "25")
|
||||
client.insert("http://example.org/bob", "age", "30")
|
||||
|
||||
# Test get_all
|
||||
all_results = list(client.get_all(limit=10))
|
||||
assert len(all_results) == 3
|
||||
print(f"✓ Stored and retrieved {len(all_results)} triples")
|
||||
|
||||
# Test get_s (subject query)
|
||||
alice_results = list(client.get_s("http://example.org/alice", limit=10))
|
||||
assert len(alice_results) == 2
|
||||
alice_predicates = [r.p for r in alice_results]
|
||||
assert "knows" in alice_predicates
|
||||
assert "age" in alice_predicates
|
||||
print("✓ Subject queries working")
|
||||
|
||||
# Test get_p (predicate query)
|
||||
age_results = list(client.get_p("age", limit=10))
|
||||
assert len(age_results) == 2
|
||||
age_subjects = [r.s for r in age_results]
|
||||
assert "http://example.org/alice" in age_subjects
|
||||
assert "http://example.org/bob" in age_subjects
|
||||
print("✓ Predicate queries working")
|
||||
|
||||
# =====================================================
|
||||
# Test 2: Storage Processor Integration
|
||||
# =====================================================
|
||||
print("\n2. Testing storage processor integration...")
|
||||
|
||||
storage_processor = StorageProcessor(
|
||||
taskgroup=MagicMock(),
|
||||
hosts=[host],
|
||||
keyspace="test_storage",
|
||||
table="test_triples"
|
||||
)
|
||||
# Track the TrustGraph instance that will be created
|
||||
self.storage_processor = storage_processor
|
||||
|
||||
# Create test message
|
||||
storage_message = Triples(
|
||||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.org/person1", is_uri=True),
|
||||
p=Value(value="http://example.org/name", is_uri=True),
|
||||
o=Value(value="Alice Smith", is_uri=False)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/person1", is_uri=True),
|
||||
p=Value(value="http://example.org/age", is_uri=True),
|
||||
o=Value(value="25", is_uri=False)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/person1", is_uri=True),
|
||||
p=Value(value="http://example.org/department", is_uri=True),
|
||||
o=Value(value="Engineering", is_uri=False)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Store triples via processor
|
||||
await storage_processor.store_triples(storage_message)
|
||||
# Track the created TrustGraph instance
|
||||
if hasattr(storage_processor, 'tg'):
|
||||
self.clients_to_close.append(storage_processor.tg)
|
||||
|
||||
# Verify data was stored
|
||||
storage_results = list(storage_processor.tg.get_s("http://example.org/person1", limit=10))
|
||||
assert len(storage_results) == 3
|
||||
|
||||
predicates = [row.p for row in storage_results]
|
||||
objects = [row.o for row in storage_results]
|
||||
|
||||
assert "http://example.org/name" in predicates
|
||||
assert "http://example.org/age" in predicates
|
||||
assert "http://example.org/department" in predicates
|
||||
assert "Alice Smith" in objects
|
||||
assert "25" in objects
|
||||
assert "Engineering" in objects
|
||||
print("✓ Storage processor working")
|
||||
|
||||
# =====================================================
|
||||
# Test 3: Query Processor Integration
|
||||
# =====================================================
|
||||
print("\n3. Testing query processor integration...")
|
||||
|
||||
query_processor = QueryProcessor(
|
||||
taskgroup=MagicMock(),
|
||||
hosts=[host],
|
||||
keyspace="test_query",
|
||||
table="test_triples"
|
||||
)
|
||||
|
||||
# Use same storage processor for the query keyspace
|
||||
query_storage_processor = StorageProcessor(
|
||||
taskgroup=MagicMock(),
|
||||
hosts=[host],
|
||||
keyspace="test_query",
|
||||
table="test_triples"
|
||||
)
|
||||
|
||||
# Store test data for querying
|
||||
query_test_message = Triples(
|
||||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.org/alice", is_uri=True),
|
||||
p=Value(value="http://example.org/knows", is_uri=True),
|
||||
o=Value(value="http://example.org/bob", is_uri=True)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/alice", is_uri=True),
|
||||
p=Value(value="http://example.org/age", is_uri=True),
|
||||
o=Value(value="30", is_uri=False)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/bob", is_uri=True),
|
||||
p=Value(value="http://example.org/knows", is_uri=True),
|
||||
o=Value(value="http://example.org/charlie", is_uri=True)
|
||||
)
|
||||
]
|
||||
)
|
||||
await query_storage_processor.store_triples(query_test_message)
|
||||
|
||||
# Debug: Check what was actually stored
|
||||
print("Debug: Checking what was stored for Alice...")
|
||||
direct_results = list(query_storage_processor.tg.get_s("http://example.org/alice", limit=10))
|
||||
print(f"Direct TrustGraph results: {len(direct_results)}")
|
||||
for result in direct_results:
|
||||
print(f" S=http://example.org/alice, P={result.p}, O={result.o}")
|
||||
|
||||
# Test S query (find all relationships for Alice)
|
||||
s_query = TriplesQueryRequest(
|
||||
s=Value(value="http://example.org/alice", is_uri=True),
|
||||
p=None, # None for wildcard
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
collection="testcol"
|
||||
)
|
||||
s_results = await query_processor.query_triples(s_query)
|
||||
print(f"Query processor results: {len(s_results)}")
|
||||
for result in s_results:
|
||||
print(f" S={result.s.value}, P={result.p.value}, O={result.o.value}")
|
||||
assert len(s_results) == 2
|
||||
|
||||
s_predicates = [t.p.value for t in s_results]
|
||||
assert "http://example.org/knows" in s_predicates
|
||||
assert "http://example.org/age" in s_predicates
|
||||
print("✓ Subject queries via processor working")
|
||||
|
||||
# Test P query (find all "knows" relationships)
|
||||
p_query = TriplesQueryRequest(
|
||||
s=None, # None for wildcard
|
||||
p=Value(value="http://example.org/knows", is_uri=True),
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
collection="testcol"
|
||||
)
|
||||
p_results = await query_processor.query_triples(p_query)
|
||||
print(p_results)
|
||||
assert len(p_results) == 2 # Alice knows Bob, Bob knows Charlie
|
||||
|
||||
p_subjects = [t.s.value for t in p_results]
|
||||
assert "http://example.org/alice" in p_subjects
|
||||
assert "http://example.org/bob" in p_subjects
|
||||
print("✓ Predicate queries via processor working")
|
||||
|
||||
# =====================================================
|
||||
# Test 4: Concurrent Operations
|
||||
# =====================================================
|
||||
print("\n4. Testing concurrent operations...")
|
||||
|
||||
concurrent_processor = StorageProcessor(
|
||||
taskgroup=MagicMock(),
|
||||
hosts=[host],
|
||||
keyspace="test_concurrent",
|
||||
table="test_triples"
|
||||
)
|
||||
|
||||
# Create multiple coroutines for concurrent storage
|
||||
async def store_person_data(person_id, name, age, department):
|
||||
message = Triples(
|
||||
metadata=Metadata(user="concurrent_test", collection="people"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
|
||||
p=Value(value="http://example.org/name", is_uri=True),
|
||||
o=Value(value=name, is_uri=False)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
|
||||
p=Value(value="http://example.org/age", is_uri=True),
|
||||
o=Value(value=str(age), is_uri=False)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
|
||||
p=Value(value="http://example.org/department", is_uri=True),
|
||||
o=Value(value=department, is_uri=False)
|
||||
)
|
||||
]
|
||||
)
|
||||
await concurrent_processor.store_triples(message)
|
||||
|
||||
# Store data for multiple people concurrently
|
||||
people_data = [
|
||||
("person1", "John Doe", 25, "Engineering"),
|
||||
("person2", "Jane Smith", 30, "Marketing"),
|
||||
("person3", "Bob Wilson", 35, "Engineering"),
|
||||
("person4", "Alice Brown", 28, "Sales"),
|
||||
]
|
||||
|
||||
# Run storage operations concurrently
|
||||
store_tasks = [store_person_data(pid, name, age, dept) for pid, name, age, dept in people_data]
|
||||
await asyncio.gather(*store_tasks)
|
||||
# Track the created TrustGraph instance
|
||||
if hasattr(concurrent_processor, 'tg'):
|
||||
self.clients_to_close.append(concurrent_processor.tg)
|
||||
|
||||
# Verify all names were stored
|
||||
name_results = list(concurrent_processor.tg.get_p("http://example.org/name", limit=10))
|
||||
assert len(name_results) == 4
|
||||
|
||||
stored_names = [r.o for r in name_results]
|
||||
expected_names = ["John Doe", "Jane Smith", "Bob Wilson", "Alice Brown"]
|
||||
|
||||
for name in expected_names:
|
||||
assert name in stored_names
|
||||
|
||||
# Verify department data
|
||||
dept_results = list(concurrent_processor.tg.get_p("http://example.org/department", limit=10))
|
||||
assert len(dept_results) == 4
|
||||
|
||||
stored_depts = [r.o for r in dept_results]
|
||||
assert "Engineering" in stored_depts
|
||||
assert "Marketing" in stored_depts
|
||||
assert "Sales" in stored_depts
|
||||
print("✓ Concurrent operations working")
|
||||
|
||||
# =====================================================
|
||||
# Test 5: Complex Queries and Data Integrity
|
||||
# =====================================================
|
||||
print("\n5. Testing complex queries and data integrity...")
|
||||
|
||||
complex_processor = StorageProcessor(
|
||||
taskgroup=MagicMock(),
|
||||
hosts=[host],
|
||||
keyspace="test_complex",
|
||||
table="test_triples"
|
||||
)
|
||||
|
||||
# Create a knowledge graph about a company
|
||||
company_graph = Triples(
|
||||
metadata=Metadata(user="integration_test", collection="company"),
|
||||
triples=[
|
||||
# People and their types
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://company.org/Employee", is_uri=True)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://company.org/bob", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://company.org/Employee", is_uri=True)
|
||||
),
|
||||
# Relationships
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/reportsTo", is_uri=True),
|
||||
o=Value(value="http://company.org/bob", is_uri=True)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/worksIn", is_uri=True),
|
||||
o=Value(value="http://company.org/engineering", is_uri=True)
|
||||
),
|
||||
# Personal info
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/fullName", is_uri=True),
|
||||
o=Value(value="Alice Johnson", is_uri=False)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/email", is_uri=True),
|
||||
o=Value(value="alice@company.org", is_uri=False)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Store the company knowledge graph
|
||||
await complex_processor.store_triples(company_graph)
|
||||
# Track the created TrustGraph instance
|
||||
if hasattr(complex_processor, 'tg'):
|
||||
self.clients_to_close.append(complex_processor.tg)
|
||||
|
||||
# Verify all Alice's data
|
||||
alice_data = list(complex_processor.tg.get_s("http://company.org/alice", limit=20))
|
||||
assert len(alice_data) == 5
|
||||
|
||||
alice_predicates = [r.p for r in alice_data]
|
||||
expected_predicates = [
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"http://company.org/reportsTo",
|
||||
"http://company.org/worksIn",
|
||||
"http://company.org/fullName",
|
||||
"http://company.org/email"
|
||||
]
|
||||
for pred in expected_predicates:
|
||||
assert pred in alice_predicates
|
||||
|
||||
# Test type-based queries
|
||||
employee_results = list(complex_processor.tg.get_p("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", limit=10))
|
||||
print(employee_results)
|
||||
assert len(employee_results) == 2
|
||||
|
||||
employees = [r.s for r in employee_results]
|
||||
assert "http://company.org/alice" in employees
|
||||
assert "http://company.org/bob" in employees
|
||||
print("✓ Complex queries and data integrity working")
|
||||
|
||||
# =====================================================
|
||||
# Summary
|
||||
# =====================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ ALL CASSANDRA INTEGRATION TESTS PASSED!")
|
||||
print("✅ Basic operations: PASSED")
|
||||
print("✅ Storage processor: PASSED")
|
||||
print("✅ Query processor: PASSED")
|
||||
print("✅ Concurrent operations: PASSED")
|
||||
print("✅ Complex queries: PASSED")
|
||||
print("=" * 60)
|
||||
456
tests/unit/test_query/test_doc_embeddings_milvus_query.py
Normal file
456
tests/unit/test_query/test_doc_embeddings_milvus_query.py
Normal file
|
|
@ -0,0 +1,456 @@
|
|||
"""
|
||||
Tests for Milvus document embeddings query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.doc_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import DocumentEmbeddingsRequest
|
||||
|
||||
|
||||
class TestMilvusDocEmbeddingsQueryProcessor:
|
||||
"""Test cases for Milvus document embeddings query processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors') as mock_doc_vectors:
|
||||
mock_vecstore = MagicMock()
|
||||
mock_doc_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-milvus-de-query',
|
||||
store_uri='http://localhost:19530'
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query_request(self):
|
||||
"""Create a mock query request for testing"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=10
|
||||
)
|
||||
return query
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors')
|
||||
def test_processor_initialization_with_defaults(self, mock_doc_vectors):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_doc_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_doc_vectors.assert_called_once_with('http://localhost:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors')
|
||||
def test_processor_initialization_with_custom_params(self, mock_doc_vectors):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_doc_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
store_uri='http://custom-milvus:19530'
|
||||
)
|
||||
|
||||
mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_single_vector(self, processor):
|
||||
"""Test querying document embeddings with a single vector"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results
|
||||
mock_results = [
|
||||
{"entity": {"doc": "First document chunk"}},
|
||||
{"entity": {"doc": "Second document chunk"}},
|
||||
{"entity": {"doc": "Third document chunk"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify search was called with correct parameters
|
||||
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)
|
||||
|
||||
# Verify results are document chunks
|
||||
assert len(result) == 3
|
||||
assert result[0] == "First document chunk"
|
||||
assert result[1] == "Second document chunk"
|
||||
assert result[2] == "Third document chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_multiple_vectors(self, processor):
|
||||
"""Test querying document embeddings with multiple vectors"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=3
|
||||
)
|
||||
|
||||
# Mock search results - different results for each vector
|
||||
mock_results_1 = [
|
||||
{"entity": {"doc": "Document from first vector"}},
|
||||
{"entity": {"doc": "Another doc from first vector"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"doc": "Document from second vector"}},
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify search was called twice with correct parameters
|
||||
expected_calls = [
|
||||
(([0.1, 0.2, 0.3],), {"limit": 3}),
|
||||
(([0.4, 0.5, 0.6],), {"limit": 3}),
|
||||
]
|
||||
assert processor.vecstore.search.call_count == 2
|
||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.search.call_args_list[i]
|
||||
assert actual_call[0] == expected_args
|
||||
assert actual_call[1] == expected_kwargs
|
||||
|
||||
# Verify results from all vectors are combined
|
||||
assert len(result) == 3
|
||||
assert "Document from first vector" in result
|
||||
assert "Another doc from first vector" in result
|
||||
assert "Document from second vector" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_with_limit(self, processor):
|
||||
"""Test querying document embeddings respects limit parameter"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Mock search results - more results than limit
|
||||
mock_results = [
|
||||
{"entity": {"doc": "Document 1"}},
|
||||
{"entity": {"doc": "Document 2"}},
|
||||
{"entity": {"doc": "Document 3"}},
|
||||
{"entity": {"doc": "Document 4"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify search was called with the specified limit
|
||||
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=2)
|
||||
|
||||
# Verify all results are returned (Milvus handles limit internally)
|
||||
assert len(result) == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_empty_vectors(self, processor):
|
||||
"""Test querying document embeddings with empty vectors list"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify no search was called
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
||||
# Verify empty results
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_empty_search_results(self, processor):
|
||||
"""Test querying document embeddings with empty search results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock empty search results
|
||||
processor.vecstore.search.return_value = []
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify search was called
|
||||
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)
|
||||
|
||||
# Verify empty results
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_unicode_documents(self, processor):
|
||||
"""Test querying document embeddings with Unicode document content"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with Unicode content
|
||||
mock_results = [
|
||||
{"entity": {"doc": "Document with Unicode: éñ中文🚀"}},
|
||||
{"entity": {"doc": "Regular ASCII document"}},
|
||||
{"entity": {"doc": "Document with émojis: 😀🎉"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify Unicode content is preserved
|
||||
assert len(result) == 3
|
||||
assert "Document with Unicode: éñ中文🚀" in result
|
||||
assert "Regular ASCII document" in result
|
||||
assert "Document with émojis: 😀🎉" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_large_documents(self, processor):
|
||||
"""Test querying document embeddings with large document content"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with large content
|
||||
large_doc = "A" * 10000 # 10KB of content
|
||||
mock_results = [
|
||||
{"entity": {"doc": large_doc}},
|
||||
{"entity": {"doc": "Small document"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify large content is preserved
|
||||
assert len(result) == 2
|
||||
assert large_doc in result
|
||||
assert "Small document" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_special_characters(self, processor):
|
||||
"""Test querying document embeddings with special characters in documents"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with special characters
|
||||
mock_results = [
|
||||
{"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}},
|
||||
{"entity": {"doc": "Document with\nnewlines\tand\ttabs"}},
|
||||
{"entity": {"doc": "Document with special chars: @#$%^&*()"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify special characters are preserved
|
||||
assert len(result) == 3
|
||||
assert "Document with \"quotes\" and 'apostrophes'" in result
|
||||
assert "Document with\nnewlines\tand\ttabs" in result
|
||||
assert "Document with special chars: @#$%^&*()" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||
"""Test querying document embeddings with zero limit"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=0
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify no search was called (optimization for zero limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
||||
# Verify empty results due to zero limit
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_negative_limit(self, processor):
|
||||
"""Test querying document embeddings with negative limit"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=-1
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify no search was called (optimization for negative limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
||||
# Verify empty results due to negative limit
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_exception_handling(self, processor):
|
||||
"""Test exception handling during query processing"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search to raise exception
|
||||
processor.vecstore.search.side_effect = Exception("Milvus connection failed")
|
||||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Milvus connection failed"):
|
||||
await processor.query_document_embeddings(query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying document embeddings with different vector dimensions"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results for each vector
|
||||
mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}]
|
||||
mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}]
|
||||
mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify all vectors were searched
|
||||
assert processor.vecstore.search.call_count == 3
|
||||
|
||||
# Verify results from all dimensions
|
||||
assert len(result) == 3
|
||||
assert "Document from 2D vector" in result
|
||||
assert "Document from 4D vector" in result
|
||||
assert "Document from 3D vector" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_duplicate_documents(self, processor):
|
||||
"""Test querying document embeddings with duplicate documents in results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with duplicates across vectors
|
||||
mock_results_1 = [
|
||||
{"entity": {"doc": "Document A"}},
|
||||
{"entity": {"doc": "Document B"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"doc": "Document B"}}, # Duplicate
|
||||
{"entity": {"doc": "Document C"}},
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Note: Unlike graph embeddings, doc embeddings don't deduplicate
|
||||
# This preserves ranking and allows multiple occurrences
|
||||
assert len(result) == 4
|
||||
assert result.count("Document B") == 2 # Should appear twice
|
||||
assert "Document A" in result
|
||||
assert "Document C" in result
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'store_uri')
|
||||
assert args.store_uri == 'http://localhost:19530'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--store-uri', 'http://custom-milvus:19530'
|
||||
])
|
||||
|
||||
assert args.store_uri == 'http://custom-milvus:19530'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
|
||||
|
||||
assert args.store_uri == 'http://short-milvus:19530'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.milvus.service.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.query.doc_embeddings.milvus.service import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n"
|
||||
)
|
||||
558
tests/unit/test_query/test_doc_embeddings_pinecone_query.py
Normal file
558
tests/unit/test_query/test_doc_embeddings_pinecone_query.py
Normal file
|
|
@ -0,0 +1,558 @@
|
|||
"""
|
||||
Tests for Pinecone document embeddings query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.doc_embeddings.pinecone.service import Processor
|
||||
|
||||
|
||||
class TestPineconeDocEmbeddingsQueryProcessor:
    """Test cases for the Pinecone document embeddings query processor.

    All Pinecone network access is mocked; these tests exercise index-name
    construction, query parameters and result extraction only.
    """

    @pytest.fixture
    def mock_query_message(self):
        """Create a mock query message for testing"""
        message = MagicMock()
        message.vectors = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6]
        ]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'
        return message

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        # Patch the Pinecone client so the Processor never opens a connection.
        with patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone') as mock_pinecone_class:
            mock_pinecone = MagicMock()
            mock_pinecone_class.return_value = mock_pinecone

            processor = Processor(
                taskgroup=MagicMock(),
                id='test-pinecone-de-query',
                api_key='test-api-key'
            )

            return processor

    @patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone')
    @patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'env-api-key')
    def test_processor_initialization_with_defaults(self, mock_pinecone_class):
        """Test processor initialization with default parameters"""
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()

        processor = Processor(taskgroup=taskgroup_mock)

        # With no explicit key, the module-level default (patched above) is used.
        mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
        assert processor.pinecone == mock_pinecone
        assert processor.api_key == 'env-api-key'

    @patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone')
    def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
        """Test processor initialization with custom parameters"""
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()

        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='custom-api-key'
        )

        mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
        assert processor.api_key == 'custom-api-key'

    @patch('trustgraph.query.doc_embeddings.pinecone.service.PineconeGRPC')
    def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
        """Test processor initialization with custom URL (GRPC mode)"""
        mock_pinecone = MagicMock()
        mock_pinecone_grpc_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()

        processor = Processor(
            taskgroup=taskgroup_mock,
            api_key='test-api-key',
            url='https://custom-host.pinecone.io'
        )

        # Supplying a URL switches the processor to the GRPC client.
        mock_pinecone_grpc_class.assert_called_once_with(
            api_key='test-api-key',
            host='https://custom-host.pinecone.io'
        )
        assert processor.pinecone == mock_pinecone
        assert processor.url == 'https://custom-host.pinecone.io'

    @patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'not-specified')
    def test_processor_initialization_missing_api_key(self):
        """Test processor initialization fails with missing API key"""
        taskgroup_mock = MagicMock()

        with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
            Processor(taskgroup=taskgroup_mock)

    @pytest.mark.asyncio
    async def test_query_document_embeddings_single_vector(self, processor):
        """Test querying document embeddings with a single vector"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 3
        message.user = 'test_user'
        message.collection = 'test_collection'

        # Mock index and query results
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': 'First document chunk'}),
            MagicMock(metadata={'doc': 'Second document chunk'}),
            MagicMock(metadata={'doc': 'Third document chunk'})
        ]
        mock_index.query.return_value = mock_results

        chunks = await processor.query_document_embeddings(message)

        # Verify index was accessed correctly.
        # Index name layout: d-<user>-<collection>-<vector dimension>.
        expected_index_name = "d-test_user-test_collection-3"
        processor.pinecone.Index.assert_called_once_with(expected_index_name)

        # Verify query parameters
        mock_index.query.assert_called_once_with(
            vector=[0.1, 0.2, 0.3],
            top_k=3,
            include_values=False,
            include_metadata=True
        )

        # Verify results
        assert len(chunks) == 3
        assert chunks[0] == 'First document chunk'
        assert chunks[1] == 'Second document chunk'
        assert chunks[2] == 'Third document chunk'

    @pytest.mark.asyncio
    async def test_query_document_embeddings_multiple_vectors(self, processor, mock_query_message):
        """Test querying document embeddings with multiple vectors"""
        # Mock index and query results
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        # First query results
        mock_results1 = MagicMock()
        mock_results1.matches = [
            MagicMock(metadata={'doc': 'Document chunk 1'}),
            MagicMock(metadata={'doc': 'Document chunk 2'})
        ]

        # Second query results
        mock_results2 = MagicMock()
        mock_results2.matches = [
            MagicMock(metadata={'doc': 'Document chunk 3'}),
            MagicMock(metadata={'doc': 'Document chunk 4'})
        ]

        mock_index.query.side_effect = [mock_results1, mock_results2]

        chunks = await processor.query_document_embeddings(mock_query_message)

        # Verify both queries were made
        assert mock_index.query.call_count == 2

        # Verify results from both queries
        assert len(chunks) == 4
        assert 'Document chunk 1' in chunks
        assert 'Document chunk 2' in chunks
        assert 'Document chunk 3' in chunks
        assert 'Document chunk 4' in chunks

    @pytest.mark.asyncio
    async def test_query_document_embeddings_limit_handling(self, processor):
        """Test that query respects the limit parameter"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 2
        message.user = 'test_user'
        message.collection = 'test_collection'

        # Mock index with many results
        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': f'Document chunk {i}'}) for i in range(10)
        ]
        mock_index.query.return_value = mock_results

        chunks = await processor.query_document_embeddings(message)

        # Verify limit is passed to query
        mock_index.query.assert_called_once()
        call_args = mock_index.query.call_args
        assert call_args[1]['top_k'] == 2

        # Results should contain all returned chunks (limit is applied by
        # Pinecone itself; the service does not re-trim the response).
        assert len(chunks) == 10

    @pytest.mark.asyncio
    async def test_query_document_embeddings_zero_limit(self, processor):
        """Test querying with zero limit returns empty results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 0
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        chunks = await processor.query_document_embeddings(message)

        # Verify no query was made and empty result returned
        mock_index.query.assert_not_called()
        assert chunks == []

    @pytest.mark.asyncio
    async def test_query_document_embeddings_negative_limit(self, processor):
        """Test querying with negative limit returns empty results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = -1
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        chunks = await processor.query_document_embeddings(message)

        # Verify no query was made and empty result returned
        mock_index.query.assert_not_called()
        assert chunks == []

    @pytest.mark.asyncio
    async def test_query_document_embeddings_different_vector_dimensions(self, processor):
        """Test querying with vectors of different dimensions"""
        message = MagicMock()
        message.vectors = [
            [0.1, 0.2],  # 2D vector
            [0.3, 0.4, 0.5, 0.6]  # 4D vector
        ]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index_2d = MagicMock()
        mock_index_4d = MagicMock()

        # Route index lookups by the dimension suffix in the index name.
        # NOTE: unexpected names fall through and return None.
        def mock_index_side_effect(name):
            if name.endswith("-2"):
                return mock_index_2d
            elif name.endswith("-4"):
                return mock_index_4d

        processor.pinecone.Index.side_effect = mock_index_side_effect

        # Mock results for different dimensions
        mock_results_2d = MagicMock()
        mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D index'})]
        mock_index_2d.query.return_value = mock_results_2d

        mock_results_4d = MagicMock()
        mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D index'})]
        mock_index_4d.query.return_value = mock_results_4d

        chunks = await processor.query_document_embeddings(message)

        # Verify different indexes were used
        assert processor.pinecone.Index.call_count == 2
        mock_index_2d.query.assert_called_once()
        mock_index_4d.query.assert_called_once()

        # Verify results from both dimensions
        assert 'Document from 2D index' in chunks
        assert 'Document from 4D index' in chunks

    @pytest.mark.asyncio
    async def test_query_document_embeddings_empty_vectors_list(self, processor):
        """Test querying with empty vectors list"""
        message = MagicMock()
        message.vectors = []
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        chunks = await processor.query_document_embeddings(message)

        # Verify no queries were made and empty result returned
        processor.pinecone.Index.assert_not_called()
        mock_index.query.assert_not_called()
        assert chunks == []

    @pytest.mark.asyncio
    async def test_query_document_embeddings_no_results(self, processor):
        """Test querying when index returns no results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        mock_results = MagicMock()
        mock_results.matches = []
        mock_index.query.return_value = mock_results

        chunks = await processor.query_document_embeddings(message)

        # Verify empty results
        assert chunks == []

    @pytest.mark.asyncio
    async def test_query_document_embeddings_unicode_content(self, processor):
        """Test querying document embeddings with Unicode content results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 2
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': 'Document with Unicode: éñ中文🚀'}),
            MagicMock(metadata={'doc': 'Regular ASCII document'})
        ]
        mock_index.query.return_value = mock_results

        chunks = await processor.query_document_embeddings(message)

        # Verify Unicode content is properly handled
        assert len(chunks) == 2
        assert 'Document with Unicode: éñ中文🚀' in chunks
        assert 'Regular ASCII document' in chunks

    @pytest.mark.asyncio
    async def test_query_document_embeddings_large_content(self, processor):
        """Test querying document embeddings with large content results"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 1
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        # Create a large document content
        large_content = "A" * 10000  # 10KB of content
        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': large_content})
        ]
        mock_index.query.return_value = mock_results

        chunks = await processor.query_document_embeddings(message)

        # Verify large content is properly handled
        assert len(chunks) == 1
        assert chunks[0] == large_content

    @pytest.mark.asyncio
    async def test_query_document_embeddings_mixed_content_types(self, processor):
        """Test querying document embeddings with mixed content types"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        mock_results = MagicMock()
        mock_results.matches = [
            MagicMock(metadata={'doc': 'Short text'}),
            MagicMock(metadata={'doc': 'A' * 1000}),  # Long text
            MagicMock(metadata={'doc': 'Text with numbers: 123 and symbols: @#$'}),
            MagicMock(metadata={'doc': '  Whitespace text  '}),
            MagicMock(metadata={'doc': ''})  # Empty string
        ]
        mock_index.query.return_value = mock_results

        chunks = await processor.query_document_embeddings(message)

        # Verify all content types are properly handled
        assert len(chunks) == 5
        assert 'Short text' in chunks
        assert 'A' * 1000 in chunks
        assert 'Text with numbers: 123 and symbols: @#$' in chunks
        assert '  Whitespace text  ' in chunks
        assert '' in chunks

    @pytest.mark.asyncio
    async def test_query_document_embeddings_exception_handling(self, processor):
        """Test that exceptions are properly raised"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index
        mock_index.query.side_effect = Exception("Query failed")

        with pytest.raises(Exception, match="Query failed"):
            await processor.query_document_embeddings(message)

    @pytest.mark.asyncio
    async def test_query_document_embeddings_index_access_failure(self, processor):
        """Test handling of index access failure"""
        message = MagicMock()
        message.vectors = [[0.1, 0.2, 0.3]]
        message.limit = 5
        message.user = 'test_user'
        message.collection = 'test_collection'

        processor.pinecone.Index.side_effect = Exception("Index access failed")

        with pytest.raises(Exception, match="Index access failed"):
            await processor.query_document_embeddings(message)

    @pytest.mark.asyncio
    async def test_query_document_embeddings_vector_accumulation(self, processor):
        """Test that results from multiple vectors are properly accumulated"""
        message = MagicMock()
        message.vectors = [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9]
        ]
        message.limit = 2
        message.user = 'test_user'
        message.collection = 'test_collection'

        mock_index = MagicMock()
        processor.pinecone.Index.return_value = mock_index

        # Each query returns different results
        mock_results1 = MagicMock()
        mock_results1.matches = [
            MagicMock(metadata={'doc': 'Doc from vector 1.1'}),
            MagicMock(metadata={'doc': 'Doc from vector 1.2'})
        ]

        mock_results2 = MagicMock()
        mock_results2.matches = [
            MagicMock(metadata={'doc': 'Doc from vector 2.1'})
        ]

        mock_results3 = MagicMock()
        mock_results3.matches = [
            MagicMock(metadata={'doc': 'Doc from vector 3.1'}),
            MagicMock(metadata={'doc': 'Doc from vector 3.2'}),
            MagicMock(metadata={'doc': 'Doc from vector 3.3'})
        ]

        mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]

        chunks = await processor.query_document_embeddings(message)

        # Verify all queries were made
        assert mock_index.query.call_count == 3

        # Verify all results are accumulated
        assert len(chunks) == 6
        assert 'Doc from vector 1.1' in chunks
        assert 'Doc from vector 1.2' in chunks
        assert 'Doc from vector 2.1' in chunks
        assert 'Doc from vector 3.1' in chunks
        assert 'Doc from vector 3.2' in chunks
        assert 'Doc from vector 3.3' in chunks

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()

        # Mock the parent class add_args method
        with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)

            # Verify parent add_args was called
            mock_parent_add_args.assert_called_once()

        # Verify our specific arguments were added
        args = parser.parse_args([])

        assert hasattr(args, 'api_key')
        assert args.api_key == 'not-specified'  # Default value when no env var
        assert hasattr(args, 'url')
        assert args.url is None

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()

        with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)

        # Test parsing with custom values
        args = parser.parse_args([
            '--api-key', 'custom-api-key',
            '--url', 'https://custom-host.pinecone.io'
        ])

        assert args.api_key == 'custom-api-key'
        assert args.url == 'https://custom-host.pinecone.io'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()

        with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)

        # Test parsing with short form
        args = parser.parse_args([
            '-a', 'short-api-key',
            '-u', 'https://short-host.pinecone.io'
        ])

        assert args.api_key == 'short-api-key'
        assert args.url == 'https://short-host.pinecone.io'

    @patch('trustgraph.query.doc_embeddings.pinecone.service.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.query.doc_embeddings.pinecone.service import run, default_ident

        run()

        mock_launch.assert_called_once_with(
            default_ident,
            "\nDocument embeddings query service. Input is vector, output is an array\nof chunks. Pinecone implementation.\n"
        )
||||
484
tests/unit/test_query/test_graph_embeddings_milvus_query.py
Normal file
484
tests/unit/test_query/test_graph_embeddings_milvus_query.py
Normal file
|
|
@ -0,0 +1,484 @@
|
|||
"""
|
||||
Tests for Milvus graph embeddings query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.graph_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import Value, GraphEmbeddingsRequest
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||
"""Test cases for Milvus graph embeddings query processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors') as mock_entity_vectors:
|
||||
mock_vecstore = MagicMock()
|
||||
mock_entity_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-milvus-ge-query',
|
||||
store_uri='http://localhost:19530'
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query_request(self):
|
||||
"""Create a mock query request for testing"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=10
|
||||
)
|
||||
return query
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors')
|
||||
def test_processor_initialization_with_defaults(self, mock_entity_vectors):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_entity_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_entity_vectors.assert_called_once_with('http://localhost:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors')
|
||||
def test_processor_initialization_with_custom_params(self, mock_entity_vectors):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_entity_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
store_uri='http://custom-milvus:19530'
|
||||
)
|
||||
|
||||
mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
"""Test querying graph embeddings with a single vector"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
{"entity": {"entity": "literal entity"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify search was called with correct parameters
|
||||
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)
|
||||
|
||||
# Verify results are converted to Value objects
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[0], Value)
|
||||
assert result[0].value == "http://example.com/entity1"
|
||||
assert result[0].is_uri is True
|
||||
assert isinstance(result[1], Value)
|
||||
assert result[1].value == "http://example.com/entity2"
|
||||
assert result[1].is_uri is True
|
||||
assert isinstance(result[2], Value)
|
||||
assert result[2].value == "literal entity"
|
||||
assert result[2].is_uri is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=3
|
||||
)
|
||||
|
||||
# Mock search results - different results for each vector
|
||||
mock_results_1 = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
|
||||
{"entity": {"entity": "http://example.com/entity3"}},
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify search was called twice with correct parameters
|
||||
expected_calls = [
|
||||
(([0.1, 0.2, 0.3],), {"limit": 6}),
|
||||
(([0.4, 0.5, 0.6],), {"limit": 6}),
|
||||
]
|
||||
assert processor.vecstore.search.call_count == 2
|
||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.search.call_args_list[i]
|
||||
assert actual_call[0] == expected_args
|
||||
assert actual_call[1] == expected_kwargs
|
||||
|
||||
# Verify results are deduplicated and limited
|
||||
assert len(result) == 3
|
||||
entity_values = [r.value for r in result]
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
assert "http://example.com/entity3" in entity_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_with_limit(self, processor):
|
||||
"""Test querying graph embeddings respects limit parameter"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Mock search results - more results than limit
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
{"entity": {"entity": "http://example.com/entity3"}},
|
||||
{"entity": {"entity": "http://example.com/entity4"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify search was called with 2*limit for better deduplication
|
||||
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)
|
||||
|
||||
# Verify results are limited to the requested limit
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_deduplication(self, processor):
|
||||
"""Test that duplicate entities are properly deduplicated"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with duplicates
|
||||
mock_results_1 = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
|
||||
{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
|
||||
{"entity": {"entity": "http://example.com/entity3"}}, # New
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify duplicates are removed
|
||||
assert len(result) == 3
|
||||
entity_values = [r.value for r in result]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
assert "http://example.com/entity3" in entity_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
|
||||
"""Test that querying stops early when limit is reached"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Mock search results - first vector returns enough results
|
||||
mock_results_1 = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
{"entity": {"entity": "http://example.com/entity3"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results_1
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify only first vector was searched (limit reached)
|
||||
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)
|
||||
|
||||
# Verify results are limited
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_empty_vectors(self, processor):
|
||||
"""Test querying graph embeddings with empty vectors list"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify no search was called
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
||||
# Verify empty results
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_empty_search_results(self, processor):
|
||||
"""Test querying graph embeddings with empty search results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock empty search results
|
||||
processor.vecstore.search.return_value = []
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify search was called
|
||||
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)
|
||||
|
||||
# Verify empty results
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
|
||||
"""Test querying graph embeddings with mixed URI and literal results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with mixed types
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/uri_entity"}},
|
||||
{"entity": {"entity": "literal entity text"}},
|
||||
{"entity": {"entity": "https://example.com/another_uri"}},
|
||||
{"entity": {"entity": "another literal"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify all results are properly typed
|
||||
assert len(result) == 4
|
||||
|
||||
# Check URI entities
|
||||
uri_results = [r for r in result if r.is_uri]
|
||||
assert len(uri_results) == 2
|
||||
uri_values = [r.value for r in uri_results]
|
||||
assert "http://example.com/uri_entity" in uri_values
|
||||
assert "https://example.com/another_uri" in uri_values
|
||||
|
||||
# Check literal entities
|
||||
literal_results = [r for r in result if not r.is_uri]
|
||||
assert len(literal_results) == 2
|
||||
literal_values = [r.value for r in literal_results]
|
||||
assert "literal entity text" in literal_values
|
||||
assert "another literal" in literal_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_exception_handling(self, processor):
|
||||
"""Test exception handling during query processing"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search to raise exception
|
||||
processor.vecstore.search.side_effect = Exception("Milvus connection failed")
|
||||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Milvus connection failed"):
|
||||
await processor.query_graph_embeddings(query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'store_uri')
|
||||
assert args.store_uri == 'http://localhost:19530'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--store-uri', 'http://custom-milvus:19530'
|
||||
])
|
||||
|
||||
assert args.store_uri == 'http://custom-milvus:19530'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
|
||||
|
||||
assert args.store_uri == 'http://short-milvus:19530'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.milvus.service.Processor.launch')
def test_run_function(self, mock_launch):
    """run() launches the processor with the default ident and doc text."""
    from trustgraph.query.graph_embeddings.milvus.service import run, default_ident

    run()

    mock_launch.assert_called_once_with(
        default_ident,
        "\nGraph embeddings query service. Input is vector, output is list of\nentities\n"
    )
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_zero_limit(self, processor):
|
||||
"""Test querying graph embeddings with zero limit"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
limit=0
|
||||
)
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify no search was called (optimization for zero limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
||||
# Verify empty results due to zero limit
|
||||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying graph embeddings with different vector dimensions"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results for each vector
|
||||
mock_results_1 = [{"entity": {"entity": "entity_2d"}}]
|
||||
mock_results_2 = [{"entity": {"entity": "entity_4d"}}]
|
||||
mock_results_3 = [{"entity": {"entity": "entity_3d"}}]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify all vectors were searched
|
||||
assert processor.vecstore.search.call_count == 3
|
||||
|
||||
# Verify results from all dimensions
|
||||
assert len(result) == 3
|
||||
entity_values = [r.value for r in result]
|
||||
assert "entity_2d" in entity_values
|
||||
assert "entity_4d" in entity_values
|
||||
assert "entity_3d" in entity_values
|
||||
# ==== new file: tests/unit/test_query/test_graph_embeddings_pinecone_query.py (507 lines) ====
|
|||
"""
|
||||
Tests for Pinecone graph embeddings query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.graph_embeddings.pinecone.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
|
||||
|
||||
class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||
"""Test cases for Pinecone graph embeddings query processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query_message(self):
|
||||
"""Create a mock query message for testing"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6]
|
||||
]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone') as mock_pinecone_class:
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-pinecone-ge-query',
|
||||
api_key='test-api-key'
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone')
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'env-api-key')
|
||||
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
|
||||
"""Test processor initialization with default parameters"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
|
||||
assert processor.pinecone == mock_pinecone
|
||||
assert processor.api_key == 'env-api-key'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone')
|
||||
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
api_key='custom-api-key'
|
||||
)
|
||||
|
||||
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
|
||||
assert processor.api_key == 'custom-api-key'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.PineconeGRPC')
|
||||
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
|
||||
"""Test processor initialization with custom URL (GRPC mode)"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_grpc_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
api_key='test-api-key',
|
||||
url='https://custom-host.pinecone.io'
|
||||
)
|
||||
|
||||
mock_pinecone_grpc_class.assert_called_once_with(
|
||||
api_key='test-api-key',
|
||||
host='https://custom-host.pinecone.io'
|
||||
)
|
||||
assert processor.pinecone == mock_pinecone
|
||||
assert processor.url == 'https://custom-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'not-specified')
|
||||
def test_processor_initialization_missing_api_key(self):
|
||||
"""Test processor initialization fails with missing API key"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
|
||||
Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
def test_create_value_uri(self, processor):
|
||||
"""Test create_value method for URI entities"""
|
||||
uri_entity = "http://example.org/entity"
|
||||
value = processor.create_value(uri_entity)
|
||||
|
||||
assert isinstance(value, Value)
|
||||
assert value.value == uri_entity
|
||||
assert value.is_uri == True
|
||||
|
||||
def test_create_value_https_uri(self, processor):
|
||||
"""Test create_value method for HTTPS URI entities"""
|
||||
uri_entity = "https://example.org/entity"
|
||||
value = processor.create_value(uri_entity)
|
||||
|
||||
assert isinstance(value, Value)
|
||||
assert value.value == uri_entity
|
||||
assert value.is_uri == True
|
||||
|
||||
def test_create_value_literal(self, processor):
|
||||
"""Test create_value method for literal entities"""
|
||||
literal_entity = "literal_entity"
|
||||
value = processor.create_value(literal_entity)
|
||||
|
||||
assert isinstance(value, Value)
|
||||
assert value.value == literal_entity
|
||||
assert value.is_uri == False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
"""Test querying graph embeddings with a single vector"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
# Mock index and query results
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': 'http://example.org/entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'http://example.org/entity3'})
|
||||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify index was accessed correctly
|
||||
expected_index_name = "t-test_user-test_collection-3"
|
||||
processor.pinecone.Index.assert_called_once_with(expected_index_name)
|
||||
|
||||
# Verify query parameters
|
||||
mock_index.query.assert_called_once_with(
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
top_k=6, # 2 * limit
|
||||
include_values=False,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert len(entities) == 3
|
||||
assert entities[0].value == 'http://example.org/entity1'
|
||||
assert entities[0].is_uri == True
|
||||
assert entities[1].value == 'entity2'
|
||||
assert entities[1].is_uri == False
|
||||
assert entities[2].value == 'http://example.org/entity3'
|
||||
assert entities[2].is_uri == True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
# Mock index and query results
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# First query results
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'})
|
||||
]
|
||||
|
||||
# Second query results
|
||||
mock_results2 = MagicMock()
|
||||
mock_results2.matches = [
|
||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity3'})
|
||||
]
|
||||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
entities = await processor.query_graph_embeddings(mock_query_message)
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
|
||||
# Verify deduplication occurred
|
||||
entity_values = [e.value for e in entities]
|
||||
assert len(entity_values) == 3
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
assert 'entity3' in entity_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_limit_handling(self, processor):
|
||||
"""Test that query respects the limit parameter"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
# Mock index with many results
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': f'entity{i}'}) for i in range(10)
|
||||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify limit is respected
|
||||
assert len(entities) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_zero_limit(self, processor):
|
||||
"""Test querying with zero limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.limit = 0
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_negative_limit(self, processor):
|
||||
"""Test querying with negative limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.limit = -1
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying with vectors of different dimensions"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||
]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
mock_index_2d = MagicMock()
|
||||
mock_index_4d = MagicMock()
|
||||
|
||||
def mock_index_side_effect(name):
|
||||
if name.endswith("-2"):
|
||||
return mock_index_2d
|
||||
elif name.endswith("-4"):
|
||||
return mock_index_4d
|
||||
|
||||
processor.pinecone.Index.side_effect = mock_index_side_effect
|
||||
|
||||
# Mock results for different dimensions
|
||||
mock_results_2d = MagicMock()
|
||||
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
|
||||
mock_index_2d.query.return_value = mock_results_2d
|
||||
|
||||
mock_results_4d = MagicMock()
|
||||
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
|
||||
mock_index_4d.query.return_value = mock_results_4d
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify different indexes were used
|
||||
assert processor.pinecone.Index.call_count == 2
|
||||
mock_index_2d.query.assert_called_once()
|
||||
mock_index_4d.query.assert_called_once()
|
||||
|
||||
# Verify results from both dimensions
|
||||
entity_values = [e.value for e in entities]
|
||||
assert 'entity_2d' in entity_values
|
||||
assert 'entity_4d' in entity_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
|
||||
"""Test querying with empty vectors list"""
|
||||
message = MagicMock()
|
||||
message.vectors = []
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify no queries were made and empty result returned
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
mock_index.query.assert_not_called()
|
||||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_no_results(self, processor):
|
||||
"""Test querying when index returns no results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = []
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify empty results
|
||||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
|
||||
"""Test that deduplication works correctly across multiple vector queries"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6]
|
||||
]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Both queries return overlapping results
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity3'}),
|
||||
MagicMock(metadata={'entity': 'entity4'})
|
||||
]
|
||||
|
||||
mock_results2 = MagicMock()
|
||||
mock_results2.matches = [
|
||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity5'})
|
||||
]
|
||||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Should get exactly 3 unique entities (respecting limit)
|
||||
assert len(entities) == 3
|
||||
entity_values = [e.value for e in entities]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
|
||||
"""Test that querying stops early when limit is reached"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# First query returns enough results to meet limit
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity3'})
|
||||
]
|
||||
mock_index.query.return_value = mock_results1
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Should only make one query since limit was reached
|
||||
mock_index.query.assert_called_once()
|
||||
assert len(entities) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_exception_handling(self, processor):
|
||||
"""Test that exceptions are properly raised"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
mock_index.query.side_effect = Exception("Query failed")
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_graph_embeddings(message)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'api_key')
|
||||
assert args.api_key == 'not-specified' # Default value when no env var
|
||||
assert hasattr(args, 'url')
|
||||
assert args.url is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
    """Test add_args with custom command line values"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    arg_parser = ArgumentParser()

    with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args'):
        Processor.add_args(arg_parser)

        # Long-form options override the registered defaults.
        parsed = arg_parser.parse_args([
            '--api-key', 'custom-api-key',
            '--url', 'https://custom-host.pinecone.io'
        ])

        assert parsed.api_key == 'custom-api-key'
        assert parsed.url == 'https://custom-host.pinecone.io'
||||
def test_add_args_short_form(self):
    """Test add_args with short form arguments"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    arg_parser = ArgumentParser()

    with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args'):
        Processor.add_args(arg_parser)

        # Single-letter aliases map to the same destinations.
        parsed = arg_parser.parse_args([
            '-a', 'short-api-key',
            '-u', 'https://short-host.pinecone.io'
        ])

        assert parsed.api_key == 'short-api-key'
        assert parsed.url == 'https://short-host.pinecone.io'
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.Processor.launch')
def test_run_function(self, mock_launch):
    """Test the run function calls Processor.launch with correct parameters"""
    from trustgraph.query.graph_embeddings.pinecone.service import run, default_ident

    # The launch banner text is part of the service contract.
    expected_doc = "\nGraph embeddings query service. Input is vector, output is list of\nentities. Pinecone implementation.\n"

    run()

    mock_launch.assert_called_once_with(default_ident, expected_doc)
556
tests/unit/test_query/test_triples_falkordb_query.py
Normal file
556
tests/unit/test_query/test_triples_falkordb_query.py
Normal file
|
|
@ -0,0 +1,556 @@
|
|||
"""
|
||||
Tests for FalkorDB triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.falkordb.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
|
||||
|
||||
class TestFalkorDBQueryProcessor:
|
||||
"""Test cases for FalkorDB query processor"""
|
||||
|
||||
@pytest.fixture
def processor(self):
    """Build a Processor wired against a fully mocked FalkorDB client."""
    with patch('trustgraph.query.triples.falkordb.service.FalkorDB'):
        kwargs = dict(
            taskgroup=MagicMock(),
            id='test-falkordb-query',
            graph_url='falkor://localhost:6379',
        )
        return Processor(**kwargs)
||||
def test_create_value_with_http_uri(self, processor):
    """Test create_value with HTTP URI"""
    value = processor.create_value("http://example.com/resource")

    assert isinstance(value, Value)
    assert value.value == "http://example.com/resource"
    assert value.is_uri is True

def test_create_value_with_https_uri(self, processor):
    """Test create_value with HTTPS URI"""
    value = processor.create_value("https://example.com/resource")

    assert isinstance(value, Value)
    assert value.value == "https://example.com/resource"
    assert value.is_uri is True

def test_create_value_with_literal(self, processor):
    """Test create_value with literal value"""
    value = processor.create_value("just a literal string")

    assert isinstance(value, Value)
    assert value.value == "just a literal string"
    assert value.is_uri is False

def test_create_value_with_empty_string(self, processor):
    """Test create_value with empty string"""
    value = processor.create_value("")

    assert isinstance(value, Value)
    assert value.value == ""
    assert value.is_uri is False

def test_create_value_with_partial_uri(self, processor):
    """Test create_value with string that looks like URI but isn't complete"""
    value = processor.create_value("http")

    assert isinstance(value, Value)
    assert value.value == "http"
    assert value.is_uri is False

def test_create_value_with_ftp_uri(self, processor):
    """Test create_value with FTP URI (should not be detected as URI)"""
    # Only http/https schemes count as URIs for this service.
    value = processor.create_value("ftp://example.com/file")

    assert isinstance(value, Value)
    assert value.value == "ftp://example.com/file"
    assert value.is_uri is False
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
def test_processor_initialization_with_defaults(self, mock_falkordb):
    """Test processor initialization with default parameters"""
    client = MagicMock()
    client.select_graph.return_value = MagicMock()
    mock_falkordb.from_url.return_value = client

    service = Processor(taskgroup=MagicMock())

    # Defaults: well-known service URL and database name.
    assert service.db == 'falkordb'
    mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379')
    client.select_graph.assert_called_once_with('falkordb')

@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
def test_processor_initialization_with_custom_params(self, mock_falkordb):
    """Test processor initialization with custom parameters"""
    client = MagicMock()
    client.select_graph.return_value = MagicMock()
    mock_falkordb.from_url.return_value = client

    service = Processor(
        taskgroup=MagicMock(),
        graph_url='falkor://custom:6379',
        database='customdb',
    )

    # Explicit arguments must flow through to the client unchanged.
    assert service.db == 'customdb'
    mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379')
    client.select_graph.assert_called_once_with('customdb')
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_spo_query(self, mock_falkordb):
    """Test SPO query (all values specified)"""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # Each issued query (literal + URI form) yields a single row.
    query_result = MagicMock()
    query_result.result_set = [["record1"]]
    graph.query.return_value = query_result

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=Value(value="literal object", is_uri=False),
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert graph.query.call_count == 2

    # The fully-bound triple is echoed back once per executed query.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subject"
    assert triples[0].p.value == "http://example.com/predicate"
    assert triples[0].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_sp_query(self, mock_falkordb):
    """Test SP query (subject and predicate specified)"""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # Literal query yields a plain string, URI query yields a URI object.
    literal_result = MagicMock()
    literal_result.result_set = [["literal result"]]
    uri_result = MagicMock()
    uri_result.result_set = [["http://example.com/uri_result"]]
    graph.query.side_effect = [literal_result, uri_result]

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=None,
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert graph.query.call_count == 2

    # Both objects come back against the fixed subject/predicate pair.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subject"
    assert triples[0].p.value == "http://example.com/predicate"
    assert triples[0].o.value == "literal result"

    assert triples[1].s.value == "http://example.com/subject"
    assert triples[1].p.value == "http://example.com/predicate"
    assert triples[1].o.value == "http://example.com/uri_result"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_so_query(self, mock_falkordb):
    """Test SO query (subject and object specified)"""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # The two issued queries each return a distinct predicate.
    first_result = MagicMock()
    first_result.result_set = [["http://example.com/pred1"]]
    second_result = MagicMock()
    second_result.result_set = [["http://example.com/pred2"]]
    graph.query.side_effect = [first_result, second_result]

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=Value(value="literal object", is_uri=False),
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert graph.query.call_count == 2

    # Each predicate binds the fixed subject and object.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subject"
    assert triples[0].p.value == "http://example.com/pred1"
    assert triples[0].o.value == "literal object"

    assert triples[1].s.value == "http://example.com/subject"
    assert triples[1].p.value == "http://example.com/pred2"
    assert triples[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_s_query(self, mock_falkordb):
    """Test S query (subject only)"""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # The two issued queries return distinct predicate-object pairs.
    first_result = MagicMock()
    first_result.result_set = [["http://example.com/pred1", "literal1"]]
    second_result = MagicMock()
    second_result.result_set = [["http://example.com/pred2", "http://example.com/uri2"]]
    graph.query.side_effect = [first_result, second_result]

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=None,
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert graph.query.call_count == 2

    # Each (predicate, object) pair is attached to the fixed subject.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subject"
    assert triples[0].p.value == "http://example.com/pred1"
    assert triples[0].o.value == "literal1"

    assert triples[1].s.value == "http://example.com/subject"
    assert triples[1].p.value == "http://example.com/pred2"
    assert triples[1].o.value == "http://example.com/uri2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_po_query(self, mock_falkordb):
    """Test PO query (predicate and object specified)"""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # The two issued queries each return a distinct subject.
    first_result = MagicMock()
    first_result.result_set = [["http://example.com/subj1"]]
    second_result = MagicMock()
    second_result.result_set = [["http://example.com/subj2"]]
    graph.query.side_effect = [first_result, second_result]

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=Value(value="literal object", is_uri=False),
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert graph.query.call_count == 2

    # Each subject binds the fixed predicate and object.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subj1"
    assert triples[0].p.value == "http://example.com/predicate"
    assert triples[0].o.value == "literal object"

    assert triples[1].s.value == "http://example.com/subj2"
    assert triples[1].p.value == "http://example.com/predicate"
    assert triples[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_p_query(self, mock_falkordb):
    """Test P query (predicate only)"""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # The two issued queries return distinct subject-object pairs.
    first_result = MagicMock()
    first_result.result_set = [["http://example.com/subj1", "literal1"]]
    second_result = MagicMock()
    second_result.result_set = [["http://example.com/subj2", "http://example.com/uri2"]]
    graph.query.side_effect = [first_result, second_result]

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=None,
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert graph.query.call_count == 2

    # Each (subject, object) pair is attached to the fixed predicate.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subj1"
    assert triples[0].p.value == "http://example.com/predicate"
    assert triples[0].o.value == "literal1"

    assert triples[1].s.value == "http://example.com/subj2"
    assert triples[1].p.value == "http://example.com/predicate"
    assert triples[1].o.value == "http://example.com/uri2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_o_query(self, mock_falkordb):
    """Test O query (object only)"""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # The two issued queries return distinct subject-predicate pairs.
    first_result = MagicMock()
    first_result.result_set = [["http://example.com/subj1", "http://example.com/pred1"]]
    second_result = MagicMock()
    second_result.result_set = [["http://example.com/subj2", "http://example.com/pred2"]]
    graph.query.side_effect = [first_result, second_result]

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=None,
        o=Value(value="literal object", is_uri=False),
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert graph.query.call_count == 2

    # Each (subject, predicate) pair is attached to the fixed object.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subj1"
    assert triples[0].p.value == "http://example.com/pred1"
    assert triples[0].o.value == "literal object"

    assert triples[1].s.value == "http://example.com/subj2"
    assert triples[1].p.value == "http://example.com/pred2"
    assert triples[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_wildcard_query(self, mock_falkordb):
    """Test wildcard query (no constraints)"""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # The two issued queries each return a full, distinct triple.
    first_result = MagicMock()
    first_result.result_set = [["http://example.com/s1", "http://example.com/p1", "literal1"]]
    second_result = MagicMock()
    second_result.result_set = [["http://example.com/s2", "http://example.com/p2", "http://example.com/o2"]]
    graph.query.side_effect = [first_result, second_result]

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=None,
        o=None,
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert graph.query.call_count == 2

    # Both full triples are returned verbatim.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/s1"
    assert triples[0].p.value == "http://example.com/p1"
    assert triples[0].o.value == "literal1"

    assert triples[1].s.value == "http://example.com/s2"
    assert triples[1].p.value == "http://example.com/p2"
    assert triples[1].o.value == "http://example.com/o2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_exception_handling(self, mock_falkordb):
    """Test exception handling during query processing"""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # Any graph query blows up with a connection failure.
    graph.query.side_effect = Exception("Database connection failed")

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=None,
        limit=100
    )

    # The failure must propagate to the caller unchanged.
    with pytest.raises(Exception, match="Database connection failed"):
        await service.query_triples(request)
def test_add_args_method(self):
    """Test that add_args properly configures argument parser"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    arg_parser = ArgumentParser()

    # Stub out the superclass hook so only this class's options are registered.
    with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args') as parent_add_args:
        Processor.add_args(arg_parser)

        parent_add_args.assert_called_once()

        # Parsing an empty argv exposes the registered defaults.
        parsed = arg_parser.parse_args([])

        assert hasattr(parsed, 'graph_url')
        assert parsed.graph_url == 'falkor://falkordb:6379'
        assert hasattr(parsed, 'database')
        assert parsed.database == 'falkordb'
def test_add_args_with_custom_values(self):
    """Test add_args with custom command line values"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    arg_parser = ArgumentParser()

    with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'):
        Processor.add_args(arg_parser)

        # Long-form options override the registered defaults.
        parsed = arg_parser.parse_args([
            '--graph-url', 'falkor://custom:6379',
            '--database', 'querydb'
        ])

        assert parsed.graph_url == 'falkor://custom:6379'
        assert parsed.database == 'querydb'
def test_add_args_short_form(self):
    """Test add_args with short form arguments"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    arg_parser = ArgumentParser()

    with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'):
        Processor.add_args(arg_parser)

        # The single-letter alias maps to the same destination.
        parsed = arg_parser.parse_args(['-g', 'falkor://short:6379'])

        assert parsed.graph_url == 'falkor://short:6379'
@patch('trustgraph.query.triples.falkordb.service.Processor.launch')
def test_run_function(self, mock_launch):
    """Test the run function calls Processor.launch with correct parameters"""
    from trustgraph.query.triples.falkordb.service import run, default_ident

    # The launch banner text is part of the service contract.
    expected_doc = "\nTriples query service for FalkorDB.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"

    run()

    mock_launch.assert_called_once_with(default_ident, expected_doc)
568
tests/unit/test_query/test_triples_memgraph_query.py
Normal file
568
tests/unit/test_query/test_triples_memgraph_query.py
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
"""
|
||||
Tests for Memgraph triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.memgraph.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
|
||||
|
||||
class TestMemgraphQueryProcessor:
|
||||
"""Test cases for Memgraph query processor"""
|
||||
|
||||
@pytest.fixture
def processor(self):
    """Build a Processor wired against a fully mocked Neo4j/Memgraph driver."""
    with patch('trustgraph.query.triples.memgraph.service.GraphDatabase'):
        kwargs = dict(
            taskgroup=MagicMock(),
            id='test-memgraph-query',
            graph_host='bolt://localhost:7687',
        )
        return Processor(**kwargs)
def test_create_value_with_http_uri(self, processor):
    """Test create_value with HTTP URI"""
    value = processor.create_value("http://example.com/resource")

    assert isinstance(value, Value)
    assert value.value == "http://example.com/resource"
    assert value.is_uri is True

def test_create_value_with_https_uri(self, processor):
    """Test create_value with HTTPS URI"""
    value = processor.create_value("https://example.com/resource")

    assert isinstance(value, Value)
    assert value.value == "https://example.com/resource"
    assert value.is_uri is True

def test_create_value_with_literal(self, processor):
    """Test create_value with literal value"""
    value = processor.create_value("just a literal string")

    assert isinstance(value, Value)
    assert value.value == "just a literal string"
    assert value.is_uri is False

def test_create_value_with_empty_string(self, processor):
    """Test create_value with empty string"""
    value = processor.create_value("")

    assert isinstance(value, Value)
    assert value.value == ""
    assert value.is_uri is False

def test_create_value_with_partial_uri(self, processor):
    """Test create_value with string that looks like URI but isn't complete"""
    value = processor.create_value("http")

    assert isinstance(value, Value)
    assert value.value == "http"
    assert value.is_uri is False

def test_create_value_with_ftp_uri(self, processor):
    """Test create_value with FTP URI (should not be detected as URI)"""
    # Only http/https schemes count as URIs for this service.
    value = processor.create_value("ftp://example.com/file")

    assert isinstance(value, Value)
    assert value.value == "ftp://example.com/file"
    assert value.is_uri is False
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
    """Test processor initialization with default parameters"""
    mock_graph_db.driver.return_value = MagicMock()

    service = Processor(taskgroup=MagicMock())

    # Defaults: well-known bolt URL, stock credentials, stock database.
    assert service.db == 'memgraph'
    mock_graph_db.driver.assert_called_once_with(
        'bolt://memgraph:7687',
        auth=('memgraph', 'password')
    )

@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
    """Test processor initialization with custom parameters"""
    mock_graph_db.driver.return_value = MagicMock()

    service = Processor(
        taskgroup=MagicMock(),
        graph_host='bolt://custom:7687',
        username='queryuser',
        password='querypass',
        database='customdb',
    )

    # Explicit arguments must flow through to the driver unchanged.
    assert service.db == 'customdb'
    mock_graph_db.driver.assert_called_once_with(
        'bolt://custom:7687',
        auth=('queryuser', 'querypass')
    )
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_spo_query(self, mock_graph_db):
    """Test SPO query (all values specified)"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # Each issued query (literal + URI form) yields a single record.
    driver.execute_query.return_value = ([MagicMock()], None, None)

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=Value(value="literal object", is_uri=False),
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert driver.execute_query.call_count == 2

    # The fully-bound triple is echoed back once per executed query.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subject"
    assert triples[0].p.value == "http://example.com/predicate"
    assert triples[0].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_sp_query(self, mock_graph_db):
    """Test SP query (subject and predicate specified)"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # Literal query yields a plain string, URI query yields a URI object.
    literal_record = MagicMock()
    literal_record.data.return_value = {"dest": "literal result"}
    uri_record = MagicMock()
    uri_record.data.return_value = {"dest": "http://example.com/uri_result"}

    driver.execute_query.side_effect = [
        ([literal_record], None, None),  # literal-object query
        ([uri_record], None, None),      # URI-object query
    ]

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=None,
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert driver.execute_query.call_count == 2

    # Both objects come back against the fixed subject/predicate pair.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subject"
    assert triples[0].p.value == "http://example.com/predicate"
    assert triples[0].o.value == "literal result"

    assert triples[1].s.value == "http://example.com/subject"
    assert triples[1].p.value == "http://example.com/predicate"
    assert triples[1].o.value == "http://example.com/uri_result"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_so_query(self, mock_graph_db):
    """Test SO query (subject and object specified)"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # The two issued queries each return a distinct predicate.
    first_record = MagicMock()
    first_record.data.return_value = {"rel": "http://example.com/pred1"}
    second_record = MagicMock()
    second_record.data.return_value = {"rel": "http://example.com/pred2"}

    driver.execute_query.side_effect = [
        ([first_record], None, None),   # literal-object query
        ([second_record], None, None),  # URI-object query
    ]

    service = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=Value(value="literal object", is_uri=False),
        limit=100
    )

    triples = await service.query_triples(request)

    # One query for literal objects plus one for URI objects.
    assert driver.execute_query.call_count == 2

    # Each predicate binds the fixed subject and object.
    assert len(triples) == 2
    assert triples[0].s.value == "http://example.com/subject"
    assert triples[0].p.value == "http://example.com/pred1"
    assert triples[0].o.value == "literal object"

    assert triples[1].s.value == "http://example.com/subject"
    assert triples[1].p.value == "http://example.com/pred2"
    assert triples[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_s_query(self, mock_graph_db):
    """Test S query (subject only)"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # Distinct predicate/object pairs, one per sub-query.
    rec_a, rec_b = MagicMock(), MagicMock()
    rec_a.data.return_value = {
        "rel": "http://example.com/pred1", "dest": "literal1",
    }
    rec_b.data.return_value = {
        "rel": "http://example.com/pred2", "dest": "http://example.com/uri2",
    }
    driver.execute_query.side_effect = [
        ([rec_a], None, None),
        ([rec_b], None, None),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=None,
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Literal-object and URI-object sub-queries both run.
    assert driver.execute_query.call_count == 2

    assert len(triples) == 2
    assert (triples[0].s.value, triples[0].p.value, triples[0].o.value) \
        == ("http://example.com/subject", "http://example.com/pred1",
            "literal1")
    assert (triples[1].s.value, triples[1].p.value, triples[1].o.value) \
        == ("http://example.com/subject", "http://example.com/pred2",
            "http://example.com/uri2")
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_po_query(self, mock_graph_db):
    """Test PO query (predicate and object specified)"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # Each sub-query matches a different subject.
    rec_a, rec_b = MagicMock(), MagicMock()
    rec_a.data.return_value = {"src": "http://example.com/subj1"}
    rec_b.data.return_value = {"src": "http://example.com/subj2"}
    driver.execute_query.side_effect = [
        ([rec_a], None, None),
        ([rec_b], None, None),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=Value(value="literal object", is_uri=False),
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Literal-object and URI-object sub-queries both run.
    assert driver.execute_query.call_count == 2

    assert len(triples) == 2
    assert (triples[0].s.value, triples[0].p.value, triples[0].o.value) \
        == ("http://example.com/subj1", "http://example.com/predicate",
            "literal object")
    assert (triples[1].s.value, triples[1].p.value, triples[1].o.value) \
        == ("http://example.com/subj2", "http://example.com/predicate",
            "literal object")
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_p_query(self, mock_graph_db):
    """Test P query (predicate only)"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # Distinct subject/object pairs, one per sub-query.
    rec_a, rec_b = MagicMock(), MagicMock()
    rec_a.data.return_value = {
        "src": "http://example.com/subj1", "dest": "literal1",
    }
    rec_b.data.return_value = {
        "src": "http://example.com/subj2", "dest": "http://example.com/uri2",
    }
    driver.execute_query.side_effect = [
        ([rec_a], None, None),
        ([rec_b], None, None),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=None,
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Literal-object and URI-object sub-queries both run.
    assert driver.execute_query.call_count == 2

    assert len(triples) == 2
    assert (triples[0].s.value, triples[0].p.value, triples[0].o.value) \
        == ("http://example.com/subj1", "http://example.com/predicate",
            "literal1")
    assert (triples[1].s.value, triples[1].p.value, triples[1].o.value) \
        == ("http://example.com/subj2", "http://example.com/predicate",
            "http://example.com/uri2")
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_o_query(self, mock_graph_db):
    """Test O query (object only)"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # Distinct subject/predicate pairs, one per sub-query.
    rec_a, rec_b = MagicMock(), MagicMock()
    rec_a.data.return_value = {
        "src": "http://example.com/subj1", "rel": "http://example.com/pred1",
    }
    rec_b.data.return_value = {
        "src": "http://example.com/subj2", "rel": "http://example.com/pred2",
    }
    driver.execute_query.side_effect = [
        ([rec_a], None, None),
        ([rec_b], None, None),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=None,
        o=Value(value="literal object", is_uri=False),
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Literal-object and URI-object sub-queries both run.
    assert driver.execute_query.call_count == 2

    assert len(triples) == 2
    assert (triples[0].s.value, triples[0].p.value, triples[0].o.value) \
        == ("http://example.com/subj1", "http://example.com/pred1",
            "literal object")
    assert (triples[1].s.value, triples[1].p.value, triples[1].o.value) \
        == ("http://example.com/subj2", "http://example.com/pred2",
            "literal object")
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_wildcard_query(self, mock_graph_db):
    """Test wildcard query (no constraints)"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # Full triples returned: one per sub-query.
    rec_a, rec_b = MagicMock(), MagicMock()
    rec_a.data.return_value = {
        "src": "http://example.com/s1",
        "rel": "http://example.com/p1",
        "dest": "literal1",
    }
    rec_b.data.return_value = {
        "src": "http://example.com/s2",
        "rel": "http://example.com/p2",
        "dest": "http://example.com/o2",
    }
    driver.execute_query.side_effect = [
        ([rec_a], None, None),
        ([rec_b], None, None),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=None,
        o=None,
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Literal-object and URI-object sub-queries both run.
    assert driver.execute_query.call_count == 2

    assert len(triples) == 2
    assert (triples[0].s.value, triples[0].p.value, triples[0].o.value) \
        == ("http://example.com/s1", "http://example.com/p1", "literal1")
    assert (triples[1].s.value, triples[1].p.value, triples[1].o.value) \
        == ("http://example.com/s2", "http://example.com/p2",
            "http://example.com/o2")
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_exception_handling(self, mock_graph_db):
    """Test exception handling during query processing"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver
    # Any database call blows up.
    driver.execute_query.side_effect = Exception("Database connection failed")

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=None,
        limit=100,
    )

    # The failure must propagate to the caller unchanged.
    with pytest.raises(Exception, match="Database connection failed"):
        await processor.query_triples(request)
def test_add_args_method(self):
    """Test that add_args properly configures argument parser"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parser = ArgumentParser()

    # Stub out the parent class contribution so only ours is exercised.
    with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args') as parent_add_args:
        Processor.add_args(parser)
        parent_add_args.assert_called_once()

    # Parsing no arguments exposes the declared defaults.
    args = parser.parse_args([])

    for name, default in (
        ('graph_host', 'bolt://memgraph:7687'),
        ('username', 'memgraph'),
        ('password', 'password'),
        ('database', 'memgraph'),
    ):
        assert hasattr(args, name)
        assert getattr(args, name) == default
def test_add_args_with_custom_values(self):
    """Test add_args with custom command line values"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parser = ArgumentParser()
    with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args'):
        Processor.add_args(parser)

    # Long-form options override every default.
    args = parser.parse_args([
        '--graph-host', 'bolt://custom:7687',
        '--username', 'queryuser',
        '--password', 'querypass',
        '--database', 'querydb',
    ])

    assert (args.graph_host, args.username, args.password, args.database) \
        == ('bolt://custom:7687', 'queryuser', 'querypass', 'querydb')
def test_add_args_short_form(self):
    """Test add_args with short form arguments"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parser = ArgumentParser()
    with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args'):
        Processor.add_args(parser)

    # The -g short option maps onto graph_host.
    assert parser.parse_args(['-g', 'bolt://short:7687']).graph_host \
        == 'bolt://short:7687'
@patch('trustgraph.query.triples.memgraph.service.Processor.launch')
def test_run_function(self, mock_launch):
    """Test the run function calls Processor.launch with correct parameters"""
    from trustgraph.query.triples.memgraph.service import default_ident, run

    run()

    # launch() receives the default ident plus the module docstring verbatim.
    mock_launch.assert_called_once_with(
        default_ident,
        "\nTriples query service for memgraph.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
    )
338
tests/unit/test_query/test_triples_neo4j_query.py
Normal file
338
tests/unit/test_query/test_triples_neo4j_query.py
Normal file
|
|
@ -0,0 +1,338 @@
|
|||
"""
|
||||
Tests for Neo4j triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.neo4j.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
|
||||
|
||||
class TestNeo4jQueryProcessor:
    """Test cases for Neo4j query processor"""

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        with patch('trustgraph.query.triples.neo4j.service.GraphDatabase'):
            return Processor(
                taskgroup=MagicMock(),
                id='test-neo4j-query',
                graph_host='bolt://localhost:7687',
            )

    def test_create_value_with_http_uri(self, processor):
        """Test create_value with HTTP URI"""
        value = processor.create_value("http://example.com/resource")
        assert isinstance(value, Value)
        assert value.value == "http://example.com/resource"
        assert value.is_uri is True

    def test_create_value_with_https_uri(self, processor):
        """Test create_value with HTTPS URI"""
        value = processor.create_value("https://example.com/resource")
        assert isinstance(value, Value)
        assert value.value == "https://example.com/resource"
        assert value.is_uri is True

    def test_create_value_with_literal(self, processor):
        """Test create_value with literal value"""
        value = processor.create_value("just a literal string")
        assert isinstance(value, Value)
        assert value.value == "just a literal string"
        assert value.is_uri is False

    def test_create_value_with_empty_string(self, processor):
        """Test create_value with empty string"""
        value = processor.create_value("")
        assert isinstance(value, Value)
        assert value.value == ""
        assert value.is_uri is False

    def test_create_value_with_partial_uri(self, processor):
        """Test create_value with string that looks like URI but isn't complete"""
        value = processor.create_value("http")
        assert isinstance(value, Value)
        assert value.value == "http"
        assert value.is_uri is False

    def test_create_value_with_ftp_uri(self, processor):
        """Test create_value with FTP URI (should not be detected as URI)"""
        value = processor.create_value("ftp://example.com/file")
        assert isinstance(value, Value)
        assert value.value == "ftp://example.com/file"
        assert value.is_uri is False

    @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
    def test_processor_initialization_with_defaults(self, mock_graph_db):
        """Test processor initialization with default parameters"""
        mock_graph_db.driver.return_value = MagicMock()

        processor = Processor(taskgroup=MagicMock())

        # Default host, credentials and database name.
        assert processor.db == 'neo4j'
        mock_graph_db.driver.assert_called_once_with(
            'bolt://neo4j:7687',
            auth=('neo4j', 'password'),
        )

    @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
    def test_processor_initialization_with_custom_params(self, mock_graph_db):
        """Test processor initialization with custom parameters"""
        mock_graph_db.driver.return_value = MagicMock()

        processor = Processor(
            taskgroup=MagicMock(),
            graph_host='bolt://custom:7687',
            username='queryuser',
            password='querypass',
            database='customdb',
        )

        assert processor.db == 'customdb'
        mock_graph_db.driver.assert_called_once_with(
            'bolt://custom:7687',
            auth=('queryuser', 'querypass'),
        )

    @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_spo_query(self, mock_graph_db):
        """Test SPO query (all values specified)"""
        driver = MagicMock()
        mock_graph_db.driver.return_value = driver
        # Both sub-queries (literal and URI) return one record each.
        driver.execute_query.return_value = ([MagicMock()], None, None)

        processor = Processor(taskgroup=MagicMock())

        request = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=Value(value="http://example.com/subject", is_uri=True),
            p=Value(value="http://example.com/predicate", is_uri=True),
            o=Value(value="literal object", is_uri=False),
            limit=100,
        )

        triples = await processor.query_triples(request)

        assert driver.execute_query.call_count == 2

        # The queried triple is echoed back once per sub-query.
        assert len(triples) == 2
        assert (triples[0].s.value, triples[0].p.value, triples[0].o.value) \
            == ("http://example.com/subject", "http://example.com/predicate",
                "literal object")

    @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_sp_query(self, mock_graph_db):
        """Test SP query (subject and predicate specified)"""
        driver = MagicMock()
        mock_graph_db.driver.return_value = driver

        # One object per sub-query: literal first, then URI.
        rec_a, rec_b = MagicMock(), MagicMock()
        rec_a.data.return_value = {"dest": "literal result"}
        rec_b.data.return_value = {"dest": "http://example.com/uri_result"}
        driver.execute_query.side_effect = [
            ([rec_a], None, None),
            ([rec_b], None, None),
        ]

        processor = Processor(taskgroup=MagicMock())

        request = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=Value(value="http://example.com/subject", is_uri=True),
            p=Value(value="http://example.com/predicate", is_uri=True),
            o=None,
            limit=100,
        )

        triples = await processor.query_triples(request)

        assert driver.execute_query.call_count == 2

        assert len(triples) == 2
        assert (triples[0].s.value, triples[0].p.value, triples[0].o.value) \
            == ("http://example.com/subject", "http://example.com/predicate",
                "literal result")
        assert (triples[1].s.value, triples[1].p.value, triples[1].o.value) \
            == ("http://example.com/subject", "http://example.com/predicate",
                "http://example.com/uri_result")

    @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_wildcard_query(self, mock_graph_db):
        """Test wildcard query (no constraints)"""
        driver = MagicMock()
        mock_graph_db.driver.return_value = driver

        rec_a, rec_b = MagicMock(), MagicMock()
        rec_a.data.return_value = {
            "src": "http://example.com/s1",
            "rel": "http://example.com/p1",
            "dest": "literal1",
        }
        rec_b.data.return_value = {
            "src": "http://example.com/s2",
            "rel": "http://example.com/p2",
            "dest": "http://example.com/o2",
        }
        driver.execute_query.side_effect = [
            ([rec_a], None, None),
            ([rec_b], None, None),
        ]

        processor = Processor(taskgroup=MagicMock())

        request = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=None,
            p=None,
            o=None,
            limit=100,
        )

        triples = await processor.query_triples(request)

        assert driver.execute_query.call_count == 2

        assert len(triples) == 2
        assert (triples[0].s.value, triples[0].p.value, triples[0].o.value) \
            == ("http://example.com/s1", "http://example.com/p1", "literal1")
        assert (triples[1].s.value, triples[1].p.value, triples[1].o.value) \
            == ("http://example.com/s2", "http://example.com/p2",
                "http://example.com/o2")

    @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_exception_handling(self, mock_graph_db):
        """Test exception handling during query processing"""
        driver = MagicMock()
        mock_graph_db.driver.return_value = driver
        driver.execute_query.side_effect = Exception("Database connection failed")

        processor = Processor(taskgroup=MagicMock())

        request = TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=Value(value="http://example.com/subject", is_uri=True),
            p=None,
            o=None,
            limit=100,
        )

        # Database failures must surface to the caller.
        with pytest.raises(Exception, match="Database connection failed"):
            await processor.query_triples(request)

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()

        # Stub out the parent class contribution so only ours is exercised.
        with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args') as parent_add_args:
            Processor.add_args(parser)
            parent_add_args.assert_called_once()

        # Parsing no arguments exposes the declared defaults.
        args = parser.parse_args([])

        for name, default in (
            ('graph_host', 'bolt://neo4j:7687'),
            ('username', 'neo4j'),
            ('password', 'password'),
            ('database', 'neo4j'),
        ):
            assert hasattr(args, name)
            assert getattr(args, name) == default

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()
        with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'):
            Processor.add_args(parser)

        # Long-form options override every default.
        args = parser.parse_args([
            '--graph-host', 'bolt://custom:7687',
            '--username', 'queryuser',
            '--password', 'querypass',
            '--database', 'querydb',
        ])

        assert (args.graph_host, args.username, args.password, args.database) \
            == ('bolt://custom:7687', 'queryuser', 'querypass', 'querydb')

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()
        with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'):
            Processor.add_args(parser)

        # The -g short option maps onto graph_host.
        assert parser.parse_args(['-g', 'bolt://short:7687']).graph_host \
            == 'bolt://short:7687'

    @patch('trustgraph.query.triples.neo4j.service.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.query.triples.neo4j.service import default_ident, run

        run()

        # launch() receives the default ident plus the module docstring verbatim.
        mock_launch.assert_called_once_with(
            default_ident,
            "\nTriples query service for neo4j.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
        )
387
tests/unit/test_storage/test_doc_embeddings_milvus_storage.py
Normal file
387
tests/unit/test_storage/test_doc_embeddings_milvus_storage.py
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
"""
|
||||
Tests for Milvus document embeddings storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.doc_embeddings.milvus.write import Processor
|
||||
from trustgraph.schema import ChunkEmbeddings
|
||||
|
||||
|
||||
class TestMilvusDocEmbeddingsStorageProcessor:
|
||||
"""Test cases for Milvus document embeddings storage processor"""
|
||||
|
||||
@pytest.fixture
def mock_message(self):
    """Create a mock message for testing"""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'

    # Two chunks: the first carries two vectors, the second just one.
    message.chunks = [
        ChunkEmbeddings(
            chunk=b"This is the first document chunk",
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        ),
        ChunkEmbeddings(
            chunk=b"This is the second document chunk",
            vectors=[[0.7, 0.8, 0.9]],
        ),
    ]

    return message
@pytest.fixture
def processor(self):
    """Create a processor instance for testing"""
    # Patch the vector store so no Milvus connection is attempted.
    with patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors') as mock_doc_vectors:
        mock_doc_vectors.return_value = MagicMock()
        return Processor(
            taskgroup=MagicMock(),
            id='test-milvus-de-storage',
            store_uri='http://localhost:19530',
        )
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
def test_processor_initialization_with_defaults(self, mock_doc_vectors):
    """Test processor initialization with default parameters"""
    vecstore = MagicMock()
    mock_doc_vectors.return_value = vecstore

    processor = Processor(taskgroup=MagicMock())

    # Default store URI is used and the store handle is kept.
    mock_doc_vectors.assert_called_once_with('http://localhost:19530')
    assert processor.vecstore == vecstore
@patch('trustgraph.storage.doc_embeddings.milvus.write.DocVectors')
def test_processor_initialization_with_custom_params(self, mock_doc_vectors):
    """Test processor initialization with custom parameters"""
    vecstore = MagicMock()
    mock_doc_vectors.return_value = vecstore

    processor = Processor(
        taskgroup=MagicMock(),
        store_uri='http://custom-milvus:19530',
    )

    # The custom store URI is forwarded and the store handle is kept.
    mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530')
    assert processor.vecstore == vecstore
@pytest.mark.asyncio
async def test_store_document_embeddings_single_chunk(self, processor):
    """Test storing document embeddings for a single chunk"""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'
    message.chunks = [
        ChunkEmbeddings(
            chunk=b"Test document content",
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        ),
    ]

    await processor.store_document_embeddings(message)

    # One insert per vector, each paired with the chunk text.
    expected = [
        ([0.1, 0.2, 0.3], "Test document content"),
        ([0.4, 0.5, 0.6], "Test document content"),
    ]
    assert processor.vecstore.insert.call_count == 2
    for call, (vec, doc) in zip(
            processor.vecstore.insert.call_args_list, expected):
        assert call[0][0] == vec
        assert call[0][1] == doc
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
    """Test storing document embeddings for multiple chunks"""
    await processor.store_document_embeddings(mock_message)

    # Every vector of every chunk becomes one insert call, in order.
    expected = [
        ([0.1, 0.2, 0.3], "This is the first document chunk"),
        ([0.4, 0.5, 0.6], "This is the first document chunk"),
        ([0.7, 0.8, 0.9], "This is the second document chunk"),
    ]
    assert processor.vecstore.insert.call_count == 3
    for call, (vec, doc) in zip(
            processor.vecstore.insert.call_args_list, expected):
        assert call[0][0] == vec
        assert call[0][1] == doc
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunk(self, processor):
    """Test storing document embeddings with empty chunk (should be skipped)"""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'
    message.chunks = [
        ChunkEmbeddings(chunk=b"", vectors=[[0.1, 0.2, 0.3]]),
    ]

    await processor.store_document_embeddings(message)

    # An empty chunk body suppresses storage entirely.
    processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_none_chunk(self, processor):
    """Test storing document embeddings with None chunk (should be skipped)"""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'
    message.chunks = [
        ChunkEmbeddings(chunk=None, vectors=[[0.1, 0.2, 0.3]]),
    ]

    await processor.store_document_embeddings(message)

    # A missing chunk body suppresses storage entirely.
    processor.vecstore.insert.assert_not_called()
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor):
|
||||
"""Test storing document embeddings with mix of valid and invalid chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_chunk = ChunkEmbeddings(
|
||||
chunk=b"Valid document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
empty_chunk = ChunkEmbeddings(
|
||||
chunk=b"",
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
none_chunk = ChunkEmbeddings(
|
||||
chunk=None,
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.chunks = [valid_chunk, empty_chunk, none_chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify only valid chunk was inserted
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], "Valid document content"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunks_list(self, processor):
|
||||
"""Test storing document embeddings with empty chunks list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.chunks = []
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no insert was called
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
|
||||
"""Test storing document embeddings for chunk with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Document with no vectors",
|
||||
vectors=[]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no insert was called (no vectors to insert)
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_store_document_embeddings_different_vector_dimensions(self, processor):
        """Test storing document embeddings with different vector dimensions"""
        # The Milvus writer is expected to insert every vector as-is,
        # regardless of its length; dimension handling is the store's concern.
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'

        chunk = ChunkEmbeddings(
            chunk=b"Document with mixed dimensions",
            vectors=[
                [0.1, 0.2],  # 2D vector
                [0.3, 0.4, 0.5, 0.6],  # 4D vector
                [0.7, 0.8, 0.9]  # 3D vector
            ]
        )
        message.chunks = [chunk]

        await processor.store_document_embeddings(message)

        # Verify all vectors were inserted regardless of dimension
        expected_calls = [
            ([0.1, 0.2], "Document with mixed dimensions"),
            ([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions"),
            ([0.7, 0.8, 0.9], "Document with mixed dimensions"),
        ]

        assert processor.vecstore.insert.call_count == 3
        # insert() is called positionally: (vector, document-text).
        for i, (expected_vec, expected_doc) in enumerate(expected_calls):
            actual_call = processor.vecstore.insert.call_args_list[i]
            assert actual_call[0][0] == expected_vec
            assert actual_call[0][1] == expected_doc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_unicode_content(self, processor):
|
||||
"""Test storing document embeddings with Unicode content"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify Unicode content was properly decoded and inserted
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_large_chunks(self, processor):
|
||||
"""Test storing document embeddings with large document chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a large document chunk
|
||||
large_content = "A" * 10000 # 10KB of content
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=large_content.encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify large content was inserted
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], large_content
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_whitespace_only_chunk(self, processor):
|
||||
"""Test storing document embeddings with whitespace-only chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b" \n\t ",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify whitespace content was inserted (not filtered out)
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], " \n\t "
|
||||
)
|
||||
|
||||
    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()

        # Mock the parent class add_args method so only this subclass's
        # arguments are registered on the parser.
        with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)

            # Verify parent add_args was called
            mock_parent_add_args.assert_called_once()

            # Verify our specific arguments were added
            # Parse empty args to check defaults
            args = parser.parse_args([])

            assert hasattr(args, 'store_uri')
            assert args.store_uri == 'http://localhost:19530'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--store-uri', 'http://custom-milvus:19530'
|
||||
])
|
||||
|
||||
assert args.store_uri == 'http://custom-milvus:19530'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.doc_embeddings.milvus.write.DocumentEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
|
||||
|
||||
assert args.store_uri == 'http://short-milvus:19530'
|
||||
|
||||
    @patch('trustgraph.storage.doc_embeddings.milvus.write.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.storage.doc_embeddings.milvus.write import run, default_ident

        run()

        # launch() receives the module's identity plus its description banner,
        # exactly as the module entry point defines them.
        mock_launch.assert_called_once_with(
            default_ident,
            "\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
        )
|
||||
536
tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py
Normal file
536
tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py
Normal file
|
|
@ -0,0 +1,536 @@
|
|||
"""
|
||||
Tests for Pinecone document embeddings storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import uuid
|
||||
|
||||
from trustgraph.storage.doc_embeddings.pinecone.write import Processor
|
||||
from trustgraph.schema import ChunkEmbeddings
|
||||
|
||||
|
||||
class TestPineconeDocEmbeddingsStorageProcessor:
|
||||
"""Test cases for Pinecone document embeddings storage processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test document embeddings
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk=b"This is the first document chunk",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk=b"This is the second document chunk",
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.chunks = [chunk1, chunk2]
|
||||
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone') as mock_pinecone_class:
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-pinecone-de-storage',
|
||||
api_key='test-api-key'
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
    @patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
    @patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'env-api-key')
    def test_processor_initialization_with_defaults(self, mock_pinecone_class):
        """Test processor initialization with default parameters"""
        # With no explicit api_key, the module-level default (normally sourced
        # from the environment) is used, and cloud/region fall back to
        # AWS us-east-1.
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone
        taskgroup_mock = MagicMock()

        processor = Processor(taskgroup=taskgroup_mock)

        mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
        assert processor.pinecone == mock_pinecone
        assert processor.api_key == 'env-api-key'
        assert processor.cloud == 'aws'
        assert processor.region == 'us-east-1'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.pinecone.write.Pinecone')
|
||||
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
api_key='custom-api-key',
|
||||
cloud='gcp',
|
||||
region='us-west1'
|
||||
)
|
||||
|
||||
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
|
||||
assert processor.api_key == 'custom-api-key'
|
||||
assert processor.cloud == 'gcp'
|
||||
assert processor.region == 'us-west1'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.pinecone.write.PineconeGRPC')
|
||||
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
|
||||
"""Test processor initialization with custom URL (GRPC mode)"""
|
||||
mock_pinecone = MagicMock()
|
||||
mock_pinecone_grpc_class.return_value = mock_pinecone
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
api_key='test-api-key',
|
||||
url='https://custom-host.pinecone.io'
|
||||
)
|
||||
|
||||
mock_pinecone_grpc_class.assert_called_once_with(
|
||||
api_key='test-api-key',
|
||||
host='https://custom-host.pinecone.io'
|
||||
)
|
||||
assert processor.pinecone == mock_pinecone
|
||||
assert processor.url == 'https://custom-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.pinecone.write.default_api_key', 'not-specified')
|
||||
def test_processor_initialization_missing_api_key(self):
|
||||
"""Test processor initialization fails with missing API key"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
|
||||
Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_single_chunk(self, processor):
|
||||
"""Test storing document embeddings for a single chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index operations
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index name and operations
|
||||
expected_index_name = "d-test_user-test_collection-3"
|
||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||
|
||||
# Verify upsert was called for each vector
|
||||
assert mock_index.upsert.call_count == 2
|
||||
|
||||
# Check first vector upsert
|
||||
first_call = mock_index.upsert.call_args_list[0]
|
||||
first_vectors = first_call[1]['vectors']
|
||||
assert len(first_vectors) == 1
|
||||
assert first_vectors[0]['id'] == 'id1'
|
||||
assert first_vectors[0]['values'] == [0.1, 0.2, 0.3]
|
||||
assert first_vectors[0]['metadata']['doc'] == "Test document content"
|
||||
|
||||
# Check second vector upsert
|
||||
second_call = mock_index.upsert.call_args_list[1]
|
||||
second_vectors = second_call[1]['vectors']
|
||||
assert len(second_vectors) == 1
|
||||
assert second_vectors[0]['id'] == 'id2'
|
||||
assert second_vectors[0]['values'] == [0.4, 0.5, 0.6]
|
||||
assert second_vectors[0]['metadata']['doc'] == "Test document content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
|
||||
"""Test storing document embeddings for multiple chunks"""
|
||||
# Mock index operations
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Verify upsert was called for each vector (3 total)
|
||||
assert mock_index.upsert.call_count == 3
|
||||
|
||||
# Verify document content in metadata
|
||||
calls = mock_index.upsert.call_args_list
|
||||
assert calls[0][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
|
||||
assert calls[1][1]['vectors'][0]['metadata']['doc'] == "This is the first document chunk"
|
||||
assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation(self, processor):
|
||||
"""Test automatic index creation when index doesn't exist"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist initially
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock index creation
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": True}
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index creation was called
|
||||
expected_index_name = "d-test_user-test_collection-3"
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
create_call = processor.pinecone.create_index.call_args
|
||||
assert create_call[1]['name'] == expected_index_name
|
||||
assert create_call[1]['dimension'] == 3
|
||||
assert create_call[1]['metric'] == "cosine"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunk(self, processor):
|
||||
"""Test storing document embeddings with empty chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no upsert was called for empty chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_none_chunk(self, processor):
|
||||
"""Test storing document embeddings with None chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=None,
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no upsert was called for None chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_decoded_chunk(self, processor):
|
||||
"""Test storing document embeddings with chunk that decodes to empty string"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"", # Empty bytes
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no upsert was called for empty decoded chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test storing document embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Document with mixed dimensions",
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index_2d = MagicMock()
|
||||
mock_index_4d = MagicMock()
|
||||
mock_index_3d = MagicMock()
|
||||
|
||||
def mock_index_side_effect(name):
|
||||
if name.endswith("-2"):
|
||||
return mock_index_2d
|
||||
elif name.endswith("-4"):
|
||||
return mock_index_4d
|
||||
elif name.endswith("-3"):
|
||||
return mock_index_3d
|
||||
|
||||
processor.pinecone.Index.side_effect = mock_index_side_effect
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify different indexes were used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 3
|
||||
mock_index_2d.upsert.assert_called_once()
|
||||
mock_index_4d.upsert.assert_called_once()
|
||||
mock_index_3d.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunks_list(self, processor):
|
||||
"""Test storing document embeddings with empty chunks list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.chunks = []
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no operations were performed
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_chunk_with_no_vectors(self, processor):
|
||||
"""Test storing document embeddings for chunk with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Document with no vectors",
|
||||
vectors=[]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify no upsert was called (no vectors to insert)
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation_failure(self, processor):
|
||||
"""Test handling of index creation failure"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist and creation fails
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index creation failed"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation_timeout(self, processor):
|
||||
"""Test handling of index creation timeout"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist and never becomes ready
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": False}
|
||||
|
||||
with patch('time.sleep'): # Speed up the test
|
||||
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_unicode_content(self, processor):
|
||||
"""Test storing document embeddings with Unicode content"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify Unicode content was properly decoded and stored
|
||||
call_args = mock_index.upsert.call_args
|
||||
stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
|
||||
assert stored_doc == "Document with Unicode: éñ中文🚀"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_large_chunks(self, processor):
|
||||
"""Test storing document embeddings with large document chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a large document chunk
|
||||
large_content = "A" * 10000 # 10KB of content
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=large_content.encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify large content was stored
|
||||
call_args = mock_index.upsert.call_args
|
||||
stored_doc = call_args[1]['vectors'][0]['metadata']['doc']
|
||||
assert stored_doc == large_content
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'api_key')
|
||||
assert args.api_key == 'not-specified' # Default value when no env var
|
||||
assert hasattr(args, 'url')
|
||||
assert args.url is None
|
||||
assert hasattr(args, 'cloud')
|
||||
assert args.cloud == 'aws'
|
||||
assert hasattr(args, 'region')
|
||||
assert args.region == 'us-east-1'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--api-key', 'custom-api-key',
|
||||
'--url', 'https://custom-host.pinecone.io',
|
||||
'--cloud', 'gcp',
|
||||
'--region', 'us-west1'
|
||||
])
|
||||
|
||||
assert args.api_key == 'custom-api-key'
|
||||
assert args.url == 'https://custom-host.pinecone.io'
|
||||
assert args.cloud == 'gcp'
|
||||
assert args.region == 'us-west1'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.doc_embeddings.pinecone.write.DocumentEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args([
|
||||
'-a', 'short-api-key',
|
||||
'-u', 'https://short-host.pinecone.io'
|
||||
])
|
||||
|
||||
assert args.api_key == 'short-api-key'
|
||||
assert args.url == 'https://short-host.pinecone.io'
|
||||
|
||||
    @patch('trustgraph.storage.doc_embeddings.pinecone.write.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.storage.doc_embeddings.pinecone.write import run, default_ident

        run()

        # launch() receives the module's identity plus its description banner,
        # exactly as the module entry point defines them.
        mock_launch.assert_called_once_with(
            default_ident,
            "\nAccepts document chunks/vector pairs and writes them to a Pinecone store.\n"
        )
|
||||
354
tests/unit/test_storage/test_graph_embeddings_milvus_storage.py
Normal file
354
tests/unit/test_storage/test_graph_embeddings_milvus_storage.py
Normal file
|
|
@ -0,0 +1,354 @@
|
|||
"""
|
||||
Tests for Milvus graph embeddings storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.graph_embeddings.milvus.write import Processor
|
||||
from trustgraph.schema import Value, EntityEmbeddings
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||
"""Test cases for Milvus graph embeddings storage processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test entities with embeddings
|
||||
entity1 = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity1', is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
entity2 = EntityEmbeddings(
|
||||
entity=Value(value='literal entity', is_uri=False),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.entities = [entity1, entity2]
|
||||
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors') as mock_entity_vectors:
|
||||
mock_vecstore = MagicMock()
|
||||
mock_entity_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-milvus-ge-storage',
|
||||
store_uri='http://localhost:19530'
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
|
||||
def test_processor_initialization_with_defaults(self, mock_entity_vectors):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_entity_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_entity_vectors.assert_called_once_with('http://localhost:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.milvus.write.EntityVectors')
|
||||
def test_processor_initialization_with_custom_params(self, mock_entity_vectors):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_vecstore = MagicMock()
|
||||
mock_entity_vectors.return_value = mock_vecstore
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
store_uri='http://custom-milvus:19530'
|
||||
)
|
||||
|
||||
mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530')
|
||||
assert processor.vecstore == mock_vecstore
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_single_entity(self, processor):
|
||||
"""Test storing graph embeddings for a single entity"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify insert was called for each vector
|
||||
expected_calls = [
|
||||
([0.1, 0.2, 0.3], 'http://example.com/entity'),
|
||||
([0.4, 0.5, 0.6], 'http://example.com/entity'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_entity
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
|
||||
"""Test storing graph embeddings for multiple entities"""
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Verify insert was called for each vector of each entity
|
||||
expected_calls = [
|
||||
# Entity 1 vectors
|
||||
([0.1, 0.2, 0.3], 'http://example.com/entity1'),
|
||||
([0.4, 0.5, 0.6], 'http://example.com/entity1'),
|
||||
# Entity 2 vectors
|
||||
([0.7, 0.8, 0.9], 'literal entity'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_entity
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, processor):
|
||||
"""Test storing graph embeddings with empty entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='', is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no insert was called for empty entity
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_none_entity_value(self, processor):
|
||||
"""Test storing graph embeddings with None entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no insert was called for None entity
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_mixed_valid_invalid_entities(self, processor):
|
||||
"""Test storing graph embeddings with mix of valid and invalid entities"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/valid', is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
empty_entity = EntityEmbeddings(
|
||||
entity=Value(value='', is_uri=False),
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
none_entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.entities = [valid_entity, empty_entity, none_entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify only valid entity was inserted
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], 'http://example.com/valid'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_empty_entities_list(self, processor):
|
||||
"""Test storing graph embeddings with empty entities list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.entities = []
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no insert was called
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
|
||||
"""Test storing graph embeddings for entity with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
vectors=[]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify no insert was called (no vectors to insert)
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test storing graph embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify all vectors were inserted regardless of dimension
|
||||
expected_calls = [
|
||||
([0.1, 0.2], 'http://example.com/entity'),
|
||||
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'),
|
||||
([0.7, 0.8, 0.9], 'http://example.com/entity'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_entity
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_uri_and_literal_entities(self, processor):
|
||||
"""Test storing graph embeddings for both URI and literal entities"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
uri_entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/uri_entity', is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
literal_entity = EntityEmbeddings(
|
||||
entity=Value(value='literal entity text', is_uri=False),
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.entities = [uri_entity, literal_entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify both entities were inserted
|
||||
expected_calls = [
|
||||
([0.1, 0.2, 0.3], 'http://example.com/uri_entity'),
|
||||
([0.4, 0.5, 0.6], 'literal entity text'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_entity
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'store_uri')
|
||||
assert args.store_uri == 'http://localhost:19530'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--store-uri', 'http://custom-milvus:19530'
|
||||
])
|
||||
|
||||
assert args.store_uri == 'http://custom-milvus:19530'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.graph_embeddings.milvus.write.GraphEmbeddingsStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
|
||||
|
||||
assert args.store_uri == 'http://short-milvus:19530'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.milvus.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.graph_embeddings.milvus.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nAccepts entity/vector pairs and writes them to a Milvus store.\n"
|
||||
)
|
||||
|
|
@ -0,0 +1,460 @@
|
|||
"""
|
||||
Tests for Pinecone graph embeddings storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import uuid
|
||||
|
||||
from trustgraph.storage.graph_embeddings.pinecone.write import Processor
|
||||
from trustgraph.schema import EntityEmbeddings, Value
|
||||
|
||||
|
||||
class TestPineconeGraphEmbeddingsStorageProcessor:
    """Unit tests for the Pinecone graph-embeddings storage processor."""

    @pytest.fixture
    def mock_message(self):
        """A message with one URI entity (two vectors) and one literal entity."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = [
            EntityEmbeddings(
                entity=Value(value="http://example.org/entity1", is_uri=True),
                vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            ),
            EntityEmbeddings(
                entity=Value(value="entity2", is_uri=False),
                vectors=[[0.7, 0.8, 0.9]],
            ),
        ]
        return msg

    @pytest.fixture
    def processor(self):
        """A Processor backed by a mocked Pinecone client."""
        with patch(
            'trustgraph.storage.graph_embeddings.pinecone.write.Pinecone'
        ) as mock_pinecone_class:
            client = MagicMock()
            mock_pinecone_class.return_value = client

            proc = Processor(
                taskgroup=MagicMock(),
                id='test-pinecone-ge-storage',
                api_key='test-api-key',
            )

            return proc

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
    @patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'env-api-key')
    def test_processor_initialization_with_defaults(self, mock_pinecone_class):
        """Defaults: environment API key, AWS cloud, us-east-1 region."""
        client = MagicMock()
        mock_pinecone_class.return_value = client

        proc = Processor(taskgroup=MagicMock())

        mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
        assert proc.pinecone == client
        assert proc.api_key == 'env-api-key'
        assert proc.cloud == 'aws'
        assert proc.region == 'us-east-1'

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.Pinecone')
    def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
        """Explicit API key, cloud and region are honoured."""
        client = MagicMock()
        mock_pinecone_class.return_value = client

        proc = Processor(
            taskgroup=MagicMock(),
            api_key='custom-api-key',
            cloud='gcp',
            region='us-west1',
        )

        mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
        assert proc.api_key == 'custom-api-key'
        assert proc.cloud == 'gcp'
        assert proc.region == 'us-west1'

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.PineconeGRPC')
    def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
        """Supplying a URL switches the processor to the GRPC client."""
        client = MagicMock()
        mock_pinecone_grpc_class.return_value = client

        proc = Processor(
            taskgroup=MagicMock(),
            api_key='test-api-key',
            url='https://custom-host.pinecone.io',
        )

        mock_pinecone_grpc_class.assert_called_once_with(
            api_key='test-api-key',
            host='https://custom-host.pinecone.io'
        )
        assert proc.pinecone == client
        assert proc.url == 'https://custom-host.pinecone.io'

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.default_api_key', 'not-specified')
    def test_processor_initialization_missing_api_key(self):
        """Construction fails loudly when no API key is configured."""
        with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
            Processor(taskgroup=MagicMock())

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_single_entity(self, processor):
        """Each vector is upserted individually into a per-dimension index."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = [
            EntityEmbeddings(
                entity=Value(value="http://example.org/entity1", is_uri=True),
                vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            )
        ]

        index = MagicMock()
        processor.pinecone.Index.return_value = index
        processor.pinecone.has_index.return_value = True

        with patch('uuid.uuid4', side_effect=['id1', 'id2']):
            await processor.store_graph_embeddings(msg)

        # Index name encodes user, collection and vector dimension.
        processor.pinecone.Index.assert_called_with(
            "t-test_user-test_collection-3")

        expected = [
            ('id1', [0.1, 0.2, 0.3]),
            ('id2', [0.4, 0.5, 0.6]),
        ]
        assert index.upsert.call_count == len(expected)
        for call, (vec_id, values) in zip(index.upsert.call_args_list, expected):
            vectors = call[1]['vectors']
            assert len(vectors) == 1
            assert vectors[0]['id'] == vec_id
            assert vectors[0]['values'] == values
            assert vectors[0]['metadata']['entity'] == "http://example.org/entity1"

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
        """All vectors across entities are upserted with entity metadata."""
        index = MagicMock()
        processor.pinecone.Index.return_value = index
        processor.pinecone.has_index.return_value = True

        with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
            await processor.store_graph_embeddings(mock_message)

        assert index.upsert.call_count == 3

        expected_entities = [
            "http://example.org/entity1",
            "http://example.org/entity1",
            "entity2",
        ]
        for call, entity in zip(index.upsert.call_args_list, expected_entities):
            assert call[1]['vectors'][0]['metadata']['entity'] == entity

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_index_creation(self, processor):
        """A missing index is created with cosine metric and matching dimension."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = [
            EntityEmbeddings(
                entity=Value(value="test_entity", is_uri=False),
                vectors=[[0.1, 0.2, 0.3]],
            )
        ]

        # Index does not yet exist; describe reports it ready immediately.
        processor.pinecone.has_index.return_value = False
        processor.pinecone.Index.return_value = MagicMock()
        processor.pinecone.describe_index.return_value.status = {"ready": True}

        with patch('uuid.uuid4', return_value='test-id'):
            await processor.store_graph_embeddings(msg)

        processor.pinecone.create_index.assert_called_once()
        kwargs = processor.pinecone.create_index.call_args[1]
        assert kwargs['name'] == "t-test_user-test_collection-3"
        assert kwargs['dimension'] == 3
        assert kwargs['metric'] == "cosine"

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_empty_entity_value(self, processor):
        """Empty-string entity values are skipped."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = [
            EntityEmbeddings(
                entity=Value(value="", is_uri=False),
                vectors=[[0.1, 0.2, 0.3]],
            )
        ]

        index = MagicMock()
        processor.pinecone.Index.return_value = index

        await processor.store_graph_embeddings(msg)

        index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_none_entity_value(self, processor):
        """None entity values are skipped."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = [
            EntityEmbeddings(
                entity=Value(value=None, is_uri=False),
                vectors=[[0.1, 0.2, 0.3]],
            )
        ]

        index = MagicMock()
        processor.pinecone.Index.return_value = index

        await processor.store_graph_embeddings(msg)

        index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
        """Each vector dimension routes to its own index."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = [
            EntityEmbeddings(
                entity=Value(value="test_entity", is_uri=False),
                vectors=[
                    [0.1, 0.2],                # 2D
                    [0.3, 0.4, 0.5, 0.6],      # 4D
                    [0.7, 0.8, 0.9],           # 3D
                ],
            )
        ]

        # One mock index per dimension, selected by the index-name suffix.
        indexes = {
            "-2": MagicMock(),
            "-4": MagicMock(),
            "-3": MagicMock(),
        }

        def pick_index(name):
            for suffix, idx in indexes.items():
                if name.endswith(suffix):
                    return idx

        processor.pinecone.Index.side_effect = pick_index
        processor.pinecone.has_index.return_value = True

        with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
            await processor.store_graph_embeddings(msg)

        assert processor.pinecone.Index.call_count == 3
        for idx in indexes.values():
            idx.upsert.assert_called_once()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_empty_entities_list(self, processor):
        """A message with no entities touches no index at all."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = []

        index = MagicMock()
        processor.pinecone.Index.return_value = index

        await processor.store_graph_embeddings(msg)

        processor.pinecone.Index.assert_not_called()
        index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_entity_with_no_vectors(self, processor):
        """An entity with an empty vector list produces no upserts."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = [
            EntityEmbeddings(
                entity=Value(value="test_entity", is_uri=False),
                vectors=[],
            )
        ]

        index = MagicMock()
        processor.pinecone.Index.return_value = index

        await processor.store_graph_embeddings(msg)

        index.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_index_creation_failure(self, processor):
        """A create_index failure propagates to the caller."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = [
            EntityEmbeddings(
                entity=Value(value="test_entity", is_uri=False),
                vectors=[[0.1, 0.2, 0.3]],
            )
        ]

        processor.pinecone.has_index.return_value = False
        processor.pinecone.create_index.side_effect = Exception("Index creation failed")

        with pytest.raises(Exception, match="Index creation failed"):
            await processor.store_graph_embeddings(msg)

    @pytest.mark.asyncio
    async def test_store_graph_embeddings_index_creation_timeout(self, processor):
        """A never-ready index eventually raises a timeout error."""
        msg = MagicMock()
        msg.metadata = MagicMock()
        msg.metadata.user = 'test_user'
        msg.metadata.collection = 'test_collection'
        msg.entities = [
            EntityEmbeddings(
                entity=Value(value="test_entity", is_uri=False),
                vectors=[[0.1, 0.2, 0.3]],
            )
        ]

        processor.pinecone.has_index.return_value = False
        processor.pinecone.describe_index.return_value.status = {"ready": False}

        # Patch out the readiness-poll delay so the test stays fast.
        with patch('time.sleep'):
            with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
                await processor.store_graph_embeddings(msg)

    def test_add_args_method(self):
        """add_args registers api-key/url/cloud/region with sane defaults."""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()

        with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args') as parent_add_args:
            Processor.add_args(parser)

            parent_add_args.assert_called_once()

        args = parser.parse_args([])

        assert hasattr(args, 'api_key')
        assert args.api_key == 'not-specified'  # Default when no env var is set
        assert hasattr(args, 'url')
        assert args.url is None
        assert hasattr(args, 'cloud')
        assert args.cloud == 'aws'
        assert hasattr(args, 'region')
        assert args.region == 'us-east-1'

    def test_add_args_with_custom_values(self):
        """Long-form options override every default."""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()

        with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
            Processor.add_args(parser)

        args = parser.parse_args([
            '--api-key', 'custom-api-key',
            '--url', 'https://custom-host.pinecone.io',
            '--cloud', 'gcp',
            '--region', 'us-west1',
        ])

        assert args.api_key == 'custom-api-key'
        assert args.url == 'https://custom-host.pinecone.io'
        assert args.cloud == 'gcp'
        assert args.region == 'us-west1'

    def test_add_args_short_form(self):
        """-a and -u short options map onto api_key and url."""
        from argparse import ArgumentParser
        from unittest.mock import patch

        parser = ArgumentParser()

        with patch('trustgraph.storage.graph_embeddings.pinecone.write.GraphEmbeddingsStoreService.add_args'):
            Processor.add_args(parser)

        args = parser.parse_args([
            '-a', 'short-api-key',
            '-u', 'https://short-host.pinecone.io',
        ])

        assert args.api_key == 'short-api-key'
        assert args.url == 'https://short-host.pinecone.io'

    @patch('trustgraph.storage.graph_embeddings.pinecone.write.Processor.launch')
    def test_run_function(self, mock_launch):
        """run() launches the processor with its default identity and banner."""
        from trustgraph.storage.graph_embeddings.pinecone.write import (
            run, default_ident,
        )

        run()

        mock_launch.assert_called_once_with(
            default_ident,
            "\nAccepts entity/vector pairs and writes them to a Pinecone store.\n"
        )
|
||||
436
tests/unit/test_storage/test_triples_falkordb_storage.py
Normal file
436
tests/unit/test_storage/test_triples_falkordb_storage.py
Normal file
|
|
@ -0,0 +1,436 @@
|
|||
"""
|
||||
Tests for FalkorDB triples storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.triples.falkordb.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
|
||||
|
||||
class TestFalkorDBStorageProcessor:
|
||||
"""Test cases for FalkorDB storage processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a test triple
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
with patch('trustgraph.storage.triples.falkordb.write.FalkorDB') as mock_falkordb:
|
||||
mock_client = MagicMock()
|
||||
mock_graph = MagicMock()
|
||||
mock_falkordb.from_url.return_value = mock_client
|
||||
mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
return Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-falkordb-storage',
|
||||
graph_url='falkor://localhost:6379',
|
||||
database='test_db'
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
|
||||
def test_processor_initialization_with_defaults(self, mock_falkordb):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_graph = MagicMock()
|
||||
mock_falkordb.from_url.return_value = mock_client
|
||||
mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.db == 'falkordb'
|
||||
mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379')
|
||||
mock_client.select_graph.assert_called_once_with('falkordb')
|
||||
|
||||
@patch('trustgraph.storage.triples.falkordb.write.FalkorDB')
|
||||
def test_processor_initialization_with_custom_params(self, mock_falkordb):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_graph = MagicMock()
|
||||
mock_falkordb.from_url.return_value = mock_client
|
||||
mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_url='falkor://custom:6379',
|
||||
database='custom_db'
|
||||
)
|
||||
|
||||
assert processor.db == 'custom_db'
|
||||
mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379')
|
||||
mock_client.select_graph.assert_called_once_with('custom_db')
|
||||
|
||||
def test_create_node(self, processor):
    """Test node creation issues a MERGE with the URI as a parameter."""
    test_uri = 'http://example.com/node'
    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10

    processor.io.query.return_value = mock_result

    processor.create_node(test_uri)

    processor.io.query.assert_called_once_with(
        "MERGE (n:Node {uri: $uri})",
        params={
            "uri": test_uri,
        },
    )
def test_create_literal(self, processor):
    """Test literal creation issues a MERGE with the value as a parameter."""
    test_value = 'test literal value'
    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10

    processor.io.query.return_value = mock_result

    processor.create_literal(test_value)

    processor.io.query.assert_called_once_with(
        "MERGE (n:Literal {value: $value})",
        params={
            "value": test_value,
        },
    )
def test_relate_node(self, processor):
    """Test node-to-node relationship creation"""
    src_uri = 'http://example.com/src'
    pred_uri = 'http://example.com/pred'
    dest_uri = 'http://example.com/dest'

    mock_result = MagicMock()
    mock_result.nodes_created = 0
    mock_result.run_time_ms = 5

    processor.io.query.return_value = mock_result

    processor.relate_node(src_uri, pred_uri, dest_uri)

    # The predicate URI is stored as a property on the Rel edge.
    processor.io.query.assert_called_once_with(
        "MATCH (src:Node {uri: $src}) "
        "MATCH (dest:Node {uri: $dest}) "
        "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
        params={
            "src": src_uri,
            "dest": dest_uri,
            "uri": pred_uri,
        },
    )
def test_relate_literal(self, processor):
    """Test node-to-literal relationship creation"""
    src_uri = 'http://example.com/src'
    pred_uri = 'http://example.com/pred'
    literal_value = 'literal destination'

    mock_result = MagicMock()
    mock_result.nodes_created = 0
    mock_result.run_time_ms = 5

    processor.io.query.return_value = mock_result

    processor.relate_literal(src_uri, pred_uri, literal_value)

    # Destination is matched as a Literal node keyed by value.
    processor.io.query.assert_called_once_with(
        "MATCH (src:Node {uri: $src}) "
        "MATCH (dest:Literal {value: $dest}) "
        "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
        params={
            "src": src_uri,
            "dest": literal_value,
            "uri": pred_uri,
        },
    )
@pytest.mark.asyncio
async def test_store_triples_with_uri_object(self, processor):
    """Test storing triple with URI object"""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'

    triple = Triple(
        s=Value(value='http://example.com/subject', is_uri=True),
        p=Value(value='http://example.com/predicate', is_uri=True),
        o=Value(value='http://example.com/object', is_uri=True)
    )
    message.triples = [triple]

    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10
    processor.io.query.return_value = mock_result

    await processor.store_triples(message)

    # Verify queries were called in the correct order
    expected_calls = [
        # Create subject node
        (("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
        # Create object node
        (("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/object"}}),
        # Create relationship
        (("MATCH (src:Node {uri: $src}) "
          "MATCH (dest:Node {uri: $dest}) "
          "MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
         {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate"}}),
    ]

    assert processor.io.query.call_count == 3
    for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
        actual_call = processor.io.query.call_args_list[i]
        assert actual_call[0] == expected_args
        assert actual_call[1] == expected_kwargs
@pytest.mark.asyncio
async def test_store_triples_with_literal_object(self, processor, mock_message):
    """Test storing triple with literal object"""
    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10
    processor.io.query.return_value = mock_result

    await processor.store_triples(mock_message)

    # Verify queries were called in the correct order
    expected_calls = [
        # Create subject node
        (("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
        # Create literal object
        (("MERGE (n:Literal {value: $value})",), {"params": {"value": "literal object"}}),
        # Create relationship
        (("MATCH (src:Node {uri: $src}) "
          "MATCH (dest:Literal {value: $dest}) "
          "MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
         {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate"}}),
    ]

    assert processor.io.query.call_count == 3
    for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
        actual_call = processor.io.query.call_args_list[i]
        assert actual_call[0] == expected_args
        assert actual_call[1] == expected_kwargs
@pytest.mark.asyncio
async def test_store_triples_multiple_triples(self, processor):
    """Test storing multiple triples"""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'

    triple1 = Triple(
        s=Value(value='http://example.com/subject1', is_uri=True),
        p=Value(value='http://example.com/predicate1', is_uri=True),
        o=Value(value='literal object1', is_uri=False)
    )
    triple2 = Triple(
        s=Value(value='http://example.com/subject2', is_uri=True),
        p=Value(value='http://example.com/predicate2', is_uri=True),
        o=Value(value='http://example.com/object2', is_uri=True)
    )
    message.triples = [triple1, triple2]

    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10
    processor.io.query.return_value = mock_result

    await processor.store_triples(message)

    # Verify total number of queries (3 per triple)
    assert processor.io.query.call_count == 6

    # Verify first triple operations
    first_triple_calls = processor.io.query.call_args_list[0:3]
    assert first_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject1"
    assert first_triple_calls[1][1]["params"]["value"] == "literal object1"
    assert first_triple_calls[2][1]["params"]["src"] == "http://example.com/subject1"

    # Verify second triple operations
    second_triple_calls = processor.io.query.call_args_list[3:6]
    assert second_triple_calls[0][1]["params"]["uri"] == "http://example.com/subject2"
    assert second_triple_calls[1][1]["params"]["uri"] == "http://example.com/object2"
    assert second_triple_calls[2][1]["params"]["src"] == "http://example.com/subject2"
@pytest.mark.asyncio
async def test_store_triples_empty_list(self, processor):
    """Test storing empty triples list"""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'
    message.triples = []

    await processor.store_triples(message)

    # Verify no queries were made
    processor.io.query.assert_not_called()
@pytest.mark.asyncio
async def test_store_triples_mixed_objects(self, processor):
    """Test storing triples with mixed URI and literal objects"""
    message = MagicMock()
    message.metadata = MagicMock()
    message.metadata.user = 'test_user'
    message.metadata.collection = 'test_collection'

    triple1 = Triple(
        s=Value(value='http://example.com/subject1', is_uri=True),
        p=Value(value='http://example.com/predicate1', is_uri=True),
        o=Value(value='literal object', is_uri=False)
    )
    triple2 = Triple(
        s=Value(value='http://example.com/subject2', is_uri=True),
        p=Value(value='http://example.com/predicate2', is_uri=True),
        o=Value(value='http://example.com/object2', is_uri=True)
    )
    message.triples = [triple1, triple2]

    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10
    processor.io.query.return_value = mock_result

    await processor.store_triples(message)

    # Verify total number of queries (3 per triple)
    assert processor.io.query.call_count == 6

    # Verify first triple creates literal (call index 1 is the object query)
    assert "Literal" in processor.io.query.call_args_list[1][0][0]
    assert processor.io.query.call_args_list[1][1]["params"]["value"] == "literal object"

    # Verify second triple creates node
    assert "Node" in processor.io.query.call_args_list[4][0][0]
    assert processor.io.query.call_args_list[4][1]["params"]["uri"] == "http://example.com/object2"
def test_add_args_method(self):
    """Test that add_args properly configures argument parser"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parser = ArgumentParser()

    # Mock the parent class add_args method so only this class's
    # arguments are registered on the parser.
    with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args') as mock_parent_add_args:
        Processor.add_args(parser)

        # Verify parent add_args was called
        mock_parent_add_args.assert_called_once()

    # Parse empty args to check the registered defaults.
    args = parser.parse_args([])

    assert hasattr(args, 'graph_url')
    assert args.graph_url == 'falkor://falkordb:6379'
    assert hasattr(args, 'database')
    assert args.database == 'falkordb'
def test_add_args_with_custom_values(self):
    """Test add_args with custom command line values"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parser = ArgumentParser()

    with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
        Processor.add_args(parser)

    # Test parsing with custom values
    args = parser.parse_args([
        '--graph-url', 'falkor://custom:6379',
        '--database', 'custom_db'
    ])

    assert args.graph_url == 'falkor://custom:6379'
    assert args.database == 'custom_db'
def test_add_args_short_form(self):
    """Test add_args with short form arguments"""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parser = ArgumentParser()

    with patch('trustgraph.storage.triples.falkordb.write.TriplesStoreService.add_args'):
        Processor.add_args(parser)

    # Test parsing with short form (-g aliases --graph-url)
    args = parser.parse_args(['-g', 'falkor://short:6379'])

    assert args.graph_url == 'falkor://short:6379'
@patch('trustgraph.storage.triples.falkordb.write.Processor.launch')
def test_run_function(self, mock_launch):
    """Test the run function calls Processor.launch with correct parameters"""
    from trustgraph.storage.triples.falkordb.write import run, default_ident

    run()

    mock_launch.assert_called_once_with(
        default_ident,
        "\nGraph writer. Input is graph edge. Writes edges to FalkorDB graph.\n"
    )
def test_create_node_with_special_characters(self, processor):
    """Test node creation with special characters in URI.

    Parameterized queries must pass the URI through untouched — no
    escaping or quoting is expected at this layer.
    """
    test_uri = 'http://example.com/node with spaces & symbols'
    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10

    processor.io.query.return_value = mock_result

    processor.create_node(test_uri)

    processor.io.query.assert_called_once_with(
        "MERGE (n:Node {uri: $uri})",
        params={
            "uri": test_uri,
        },
    )
def test_create_literal_with_special_characters(self, processor):
    """Test literal creation with special characters.

    Quotes and newlines must round-trip unchanged through the
    parameterized query.
    """
    test_value = 'literal with "quotes" and \n newlines'
    mock_result = MagicMock()
    mock_result.nodes_created = 1
    mock_result.run_time_ms = 10

    processor.io.query.return_value = mock_result

    processor.create_literal(test_value)

    processor.io.query.assert_called_once_with(
        "MERGE (n:Literal {value: $value})",
        params={
            "value": test_value,
        },
    )
441
tests/unit/test_storage/test_triples_memgraph_storage.py
Normal file
441
tests/unit/test_storage/test_triples_memgraph_storage.py
Normal file
|
|
@ -0,0 +1,441 @@
|
|||
"""
|
||||
Tests for Memgraph triples storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.triples.memgraph.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
|
||||
|
||||
class TestMemgraphStorageProcessor:
    """Test cases for Memgraph storage processor"""

    @pytest.fixture
    def mock_message(self):
        """Create a mock message carrying one URI -> literal triple."""
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'

        # Create a test triple
        triple = Triple(
            s=Value(value='http://example.com/subject', is_uri=True),
            p=Value(value='http://example.com/predicate', is_uri=True),
            o=Value(value='literal object', is_uri=False)
        )
        message.triples = [triple]

        return message

    @pytest.fixture
    def processor(self):
        """Create a Processor wired to a mocked Bolt driver (no real Memgraph)."""
        with patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') as mock_graph_db:
            mock_driver = MagicMock()
            mock_session = MagicMock()
            mock_graph_db.driver.return_value = mock_driver
            mock_driver.session.return_value.__enter__.return_value = mock_session

            return Processor(
                taskgroup=MagicMock(),
                id='test-memgraph-storage',
                graph_host='bolt://localhost:7687',
                username='test_user',
                password='test_pass',
                database='test_db'
            )

    @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
    def test_processor_initialization_with_defaults(self, mock_graph_db):
        """Test processor initialization with default parameters"""
        taskgroup_mock = MagicMock()
        mock_driver = MagicMock()
        mock_session = MagicMock()
        mock_graph_db.driver.return_value = mock_driver
        mock_driver.session.return_value.__enter__.return_value = mock_session

        processor = Processor(taskgroup=taskgroup_mock)

        assert processor.db == 'memgraph'
        mock_graph_db.driver.assert_called_once_with(
            'bolt://memgraph:7687',
            auth=('memgraph', 'password')
        )

    @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
    def test_processor_initialization_with_custom_params(self, mock_graph_db):
        """Test processor initialization with custom parameters"""
        taskgroup_mock = MagicMock()
        mock_driver = MagicMock()
        mock_session = MagicMock()
        mock_graph_db.driver.return_value = mock_driver
        mock_driver.session.return_value.__enter__.return_value = mock_session

        processor = Processor(
            taskgroup=taskgroup_mock,
            graph_host='bolt://custom:7687',
            username='custom_user',
            password='custom_pass',
            database='custom_db'
        )

        assert processor.db == 'custom_db'
        mock_graph_db.driver.assert_called_once_with(
            'bolt://custom:7687',
            auth=('custom_user', 'custom_pass')
        )

    @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
    def test_create_indexes_success(self, mock_graph_db):
        """Test successful index creation"""
        taskgroup_mock = MagicMock()
        mock_driver = MagicMock()
        mock_session = MagicMock()
        mock_graph_db.driver.return_value = mock_driver
        mock_driver.session.return_value.__enter__.return_value = mock_session

        processor = Processor(taskgroup=taskgroup_mock)

        # Index creation calls must happen in this exact order.
        expected_calls = [
            "CREATE INDEX ON :Node",
            "CREATE INDEX ON :Node(uri)",
            "CREATE INDEX ON :Literal",
            "CREATE INDEX ON :Literal(value)"
        ]

        assert mock_session.run.call_count == len(expected_calls)
        for i, expected_call in enumerate(expected_calls):
            actual_call = mock_session.run.call_args_list[i][0][0]
            assert actual_call == expected_call

    @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
    def test_create_indexes_with_exceptions(self, mock_graph_db):
        """Test index creation with exceptions (should be ignored)"""
        taskgroup_mock = MagicMock()
        mock_driver = MagicMock()
        mock_session = MagicMock()
        mock_graph_db.driver.return_value = mock_driver
        mock_driver.session.return_value.__enter__.return_value = mock_session

        # Make all index creation calls raise exceptions
        mock_session.run.side_effect = Exception("Index already exists")

        # Should not raise an exception
        processor = Processor(taskgroup=taskgroup_mock)

        # Verify all index creation calls were attempted
        assert mock_session.run.call_count == 4

    def test_create_node(self, processor):
        """Test node creation"""
        test_uri = 'http://example.com/node'
        mock_result = MagicMock()
        mock_summary = MagicMock()
        mock_summary.counters.nodes_created = 1
        mock_summary.result_available_after = 10
        mock_result.summary = mock_summary

        processor.io.execute_query.return_value = mock_result

        processor.create_node(test_uri)

        processor.io.execute_query.assert_called_once_with(
            "MERGE (n:Node {uri: $uri})",
            uri=test_uri,
            database_=processor.db
        )

    def test_create_literal(self, processor):
        """Test literal creation"""
        test_value = 'test literal value'
        mock_result = MagicMock()
        mock_summary = MagicMock()
        mock_summary.counters.nodes_created = 1
        mock_summary.result_available_after = 10
        mock_result.summary = mock_summary

        processor.io.execute_query.return_value = mock_result

        processor.create_literal(test_value)

        processor.io.execute_query.assert_called_once_with(
            "MERGE (n:Literal {value: $value})",
            value=test_value,
            database_=processor.db
        )

    def test_relate_node(self, processor):
        """Test node-to-node relationship creation"""
        src_uri = 'http://example.com/src'
        pred_uri = 'http://example.com/pred'
        dest_uri = 'http://example.com/dest'

        mock_result = MagicMock()
        mock_summary = MagicMock()
        mock_summary.counters.nodes_created = 0
        mock_summary.result_available_after = 5
        mock_result.summary = mock_summary

        processor.io.execute_query.return_value = mock_result

        processor.relate_node(src_uri, pred_uri, dest_uri)

        processor.io.execute_query.assert_called_once_with(
            "MATCH (src:Node {uri: $src}) "
            "MATCH (dest:Node {uri: $dest}) "
            "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
            src=src_uri, dest=dest_uri, uri=pred_uri,
            database_=processor.db
        )

    def test_relate_literal(self, processor):
        """Test node-to-literal relationship creation"""
        src_uri = 'http://example.com/src'
        pred_uri = 'http://example.com/pred'
        literal_value = 'literal destination'

        mock_result = MagicMock()
        mock_summary = MagicMock()
        mock_summary.counters.nodes_created = 0
        mock_summary.result_available_after = 5
        mock_result.summary = mock_summary

        processor.io.execute_query.return_value = mock_result

        processor.relate_literal(src_uri, pred_uri, literal_value)

        processor.io.execute_query.assert_called_once_with(
            "MATCH (src:Node {uri: $src}) "
            "MATCH (dest:Literal {value: $dest}) "
            "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
            src=src_uri, dest=literal_value, uri=pred_uri,
            database_=processor.db
        )

    def test_create_triple_with_uri_object(self, processor):
        """Test triple creation with URI object"""
        mock_tx = MagicMock()

        triple = Triple(
            s=Value(value='http://example.com/subject', is_uri=True),
            p=Value(value='http://example.com/predicate', is_uri=True),
            o=Value(value='http://example.com/object', is_uri=True)
        )

        processor.create_triple(mock_tx, triple)

        # Verify transaction calls
        expected_calls = [
            # Create subject node
            ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
            # Create object node
            ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/object'}),
            # Create relationship
            ("MATCH (src:Node {uri: $src}) "
             "MATCH (dest:Node {uri: $dest}) "
             "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
             {'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate'})
        ]

        assert mock_tx.run.call_count == 3
        for i, (expected_query, expected_params) in enumerate(expected_calls):
            actual_call = mock_tx.run.call_args_list[i]
            assert actual_call[0][0] == expected_query
            assert actual_call[1] == expected_params

    def test_create_triple_with_literal_object(self, processor):
        """Test triple creation with literal object"""
        mock_tx = MagicMock()

        triple = Triple(
            s=Value(value='http://example.com/subject', is_uri=True),
            p=Value(value='http://example.com/predicate', is_uri=True),
            o=Value(value='literal object', is_uri=False)
        )

        processor.create_triple(mock_tx, triple)

        # Verify transaction calls
        expected_calls = [
            # Create subject node
            ("MERGE (n:Node {uri: $uri})", {'uri': 'http://example.com/subject'}),
            # Create literal object
            ("MERGE (n:Literal {value: $value})", {'value': 'literal object'}),
            # Create relationship
            ("MATCH (src:Node {uri: $src}) "
             "MATCH (dest:Literal {value: $dest}) "
             "MERGE (src)-[:Rel {uri: $uri}]->(dest)",
             {'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate'})
        ]

        assert mock_tx.run.call_count == 3
        for i, (expected_query, expected_params) in enumerate(expected_calls):
            actual_call = mock_tx.run.call_args_list[i]
            assert actual_call[0][0] == expected_query
            assert actual_call[1] == expected_params

    @pytest.mark.asyncio
    async def test_store_triples_single_triple(self, processor, mock_message):
        """Test storing a single triple"""
        mock_session = MagicMock()
        processor.io.session.return_value.__enter__.return_value = mock_session

        # Reset the mock to clear the initialization call
        processor.io.session.reset_mock()

        await processor.store_triples(mock_message)

        # Verify session was created with correct database
        processor.io.session.assert_called_once_with(database=processor.db)

        # Verify execute_write was called once per triple
        mock_session.execute_write.assert_called_once()

        # Verify the triple was passed to create_triple
        call_args = mock_session.execute_write.call_args
        assert call_args[0][0] == processor.create_triple
        assert call_args[0][1] == mock_message.triples[0]

    @pytest.mark.asyncio
    async def test_store_triples_multiple_triples(self, processor):
        """Test storing multiple triples"""
        mock_session = MagicMock()
        processor.io.session.return_value.__enter__.return_value = mock_session

        # Reset the mock to clear the initialization call
        processor.io.session.reset_mock()

        # Create message with multiple triples
        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'

        triple1 = Triple(
            s=Value(value='http://example.com/subject1', is_uri=True),
            p=Value(value='http://example.com/predicate1', is_uri=True),
            o=Value(value='literal object1', is_uri=False)
        )
        triple2 = Triple(
            s=Value(value='http://example.com/subject2', is_uri=True),
            p=Value(value='http://example.com/predicate2', is_uri=True),
            o=Value(value='http://example.com/object2', is_uri=True)
        )
        message.triples = [triple1, triple2]

        await processor.store_triples(message)

        # Verify session was called twice (once per triple)
        assert processor.io.session.call_count == 2

        # Verify execute_write was called once per triple
        assert mock_session.execute_write.call_count == 2

        # Verify each triple was processed
        call_args_list = mock_session.execute_write.call_args_list
        assert call_args_list[0][0][1] == triple1
        assert call_args_list[1][0][1] == triple2

    @pytest.mark.asyncio
    async def test_store_triples_empty_list(self, processor):
        """Test storing empty triples list"""
        mock_session = MagicMock()
        processor.io.session.return_value.__enter__.return_value = mock_session

        # Reset the mock to clear the initialization call
        processor.io.session.reset_mock()

        message = MagicMock()
        message.metadata = MagicMock()
        message.metadata.user = 'test_user'
        message.metadata.collection = 'test_collection'
        message.triples = []

        await processor.store_triples(message)

        # Verify no session calls were made (no triples to process)
        processor.io.session.assert_not_called()

        # Verify no execute_write calls were made
        mock_session.execute_write.assert_not_called()

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser

        parser = ArgumentParser()

        # Mock the parent class add_args method
        # (patch is already imported at module level)
        with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)

            # Verify parent add_args was called
            mock_parent_add_args.assert_called_once()

        # Parse empty args to check defaults
        args = parser.parse_args([])

        assert hasattr(args, 'graph_host')
        assert args.graph_host == 'bolt://memgraph:7687'
        assert hasattr(args, 'username')
        assert args.username == 'memgraph'
        assert hasattr(args, 'password')
        assert args.password == 'password'
        assert hasattr(args, 'database')
        assert args.database == 'memgraph'

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser

        parser = ArgumentParser()

        with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
            Processor.add_args(parser)

        # Test parsing with custom values
        args = parser.parse_args([
            '--graph-host', 'bolt://custom:7687',
            '--username', 'custom_user',
            '--password', 'custom_pass',
            '--database', 'custom_db'
        ])

        assert args.graph_host == 'bolt://custom:7687'
        assert args.username == 'custom_user'
        assert args.password == 'custom_pass'
        assert args.database == 'custom_db'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser

        parser = ArgumentParser()

        with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args'):
            Processor.add_args(parser)

        # Test parsing with short form (-g aliases --graph-host)
        args = parser.parse_args(['-g', 'bolt://short:7687'])

        assert args.graph_host == 'bolt://short:7687'

    @patch('trustgraph.storage.triples.memgraph.write.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.storage.triples.memgraph.write import run, default_ident

        run()

        mock_launch.assert_called_once_with(
            default_ident,
            "\nGraph writer. Input is graph edge. Writes edges to Memgraph.\n"
        )
548
tests/unit/test_storage/test_triples_neo4j_storage.py
Normal file
548
tests/unit/test_storage/test_triples_neo4j_storage.py
Normal file
|
|
@ -0,0 +1,548 @@
|
|||
"""
|
||||
Tests for Neo4j triples storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.storage.triples.neo4j.write import Processor
|
||||
|
||||
|
||||
class TestNeo4jStorageProcessor:
|
||||
"""Test cases for Neo4j storage processor"""
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
    """Test processor initialization with default parameters"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session

    processor = Processor(taskgroup=taskgroup_mock)

    assert processor.db == 'neo4j'
    mock_graph_db.driver.assert_called_once_with(
        'bolt://neo4j:7687',
        auth=('neo4j', 'password')
    )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
    """Test processor initialization with custom parameters"""
    taskgroup_mock = MagicMock()
    mock_driver = MagicMock()
    mock_graph_db.driver.return_value = mock_driver
    mock_session = MagicMock()
    mock_driver.session.return_value.__enter__.return_value = mock_session

    processor = Processor(
        taskgroup=taskgroup_mock,
        graph_host='bolt://custom:7687',
        username='testuser',
        password='testpass',
        database='testdb'
    )

    assert processor.db == 'testdb'
    mock_graph_db.driver.assert_called_once_with(
        'bolt://custom:7687',
        auth=('testuser', 'testpass')
    )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_create_indexes_success(self, mock_graph_db):
|
||||
"""Test successful index creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Verify index creation queries were executed
|
||||
expected_calls = [
|
||||
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
|
||||
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
|
||||
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)"
|
||||
]
|
||||
|
||||
assert mock_session.run.call_count == 3
|
||||
for expected_query in expected_calls:
|
||||
mock_session.run.assert_any_call(expected_query)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_create_indexes_with_exceptions(self, mock_graph_db):
|
||||
"""Test index creation with exceptions (should be ignored)"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Make session.run raise exceptions
|
||||
mock_session.run.side_effect = Exception("Index already exists")
|
||||
|
||||
# Should not raise exception - they should be caught and ignored
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Should have tried to create all 3 indexes despite exceptions
|
||||
assert mock_session.run.call_count == 3
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_create_node(self, mock_graph_db):
|
||||
"""Test node creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test create_node
|
||||
processor.create_node("http://example.com/node")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
uri="http://example.com/node",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_create_literal(self, mock_graph_db):
|
||||
"""Test literal creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test create_literal
|
||||
processor.create_literal("literal value")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
value="literal value",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_relate_node(self, mock_graph_db):
|
||||
"""Test node-to-node relationship creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test relate_node
|
||||
processor.relate_node(
|
||||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"http://example.com/object"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="http://example.com/object",
|
||||
uri="http://example.com/predicate",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_relate_literal(self, mock_graph_db):
|
||||
"""Test node-to-literal relationship creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test relate_literal
|
||||
processor.relate_literal(
|
||||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"literal value"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="literal value",
|
||||
uri="http://example.com/predicate",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_triples_with_uri_object(self, mock_graph_db):
|
||||
"""Test handling triples message with URI object"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triple with URI object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "http://example.com/object"
|
||||
triple.o.is_uri = True
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify create_node was called for subject and object
|
||||
# Verify relate_node was called
|
||||
expected_calls = [
|
||||
# Subject node creation
|
||||
(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
{"uri": "http://example.com/subject", "database_": "neo4j"}
|
||||
),
|
||||
# Object node creation
|
||||
(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
{"uri": "http://example.com/object", "database_": "neo4j"}
|
||||
),
|
||||
# Relationship creation
|
||||
(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Node {uri: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
{
|
||||
"src": "http://example.com/subject",
|
||||
"dest": "http://example.com/object",
|
||||
"uri": "http://example.com/predicate",
|
||||
"database_": "neo4j"
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
assert mock_driver.execute_query.call_count == 3
|
||||
for expected_query, expected_params in expected_calls:
|
||||
mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_literal_object(self, mock_graph_db):
|
||||
"""Test handling triples message with literal object"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triple with literal object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "literal value"
|
||||
triple.o.is_uri = False
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify create_node was called for subject
|
||||
# Verify create_literal was called for object
|
||||
# Verify relate_literal was called
|
||||
expected_calls = [
|
||||
# Subject node creation
|
||||
(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
{"uri": "http://example.com/subject", "database_": "neo4j"}
|
||||
),
|
||||
# Literal creation
|
||||
(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
{"value": "literal value", "database_": "neo4j"}
|
||||
),
|
||||
# Relationship creation
|
||||
(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
{
|
||||
"src": "http://example.com/subject",
|
||||
"dest": "literal value",
|
||||
"uri": "http://example.com/predicate",
|
||||
"database_": "neo4j"
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
assert mock_driver.execute_query.call_count == 3
|
||||
for expected_query, expected_params in expected_calls:
|
||||
mock_driver.execute_query.assert_any_call(expected_query, **expected_params)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_multiple_triples(self, mock_graph_db):
|
||||
"""Test handling message with multiple triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triples
|
||||
triple1 = MagicMock()
|
||||
triple1.s.value = "http://example.com/subject1"
|
||||
triple1.p.value = "http://example.com/predicate1"
|
||||
triple1.o.value = "http://example.com/object1"
|
||||
triple1.o.is_uri = True
|
||||
|
||||
triple2 = MagicMock()
|
||||
triple2.s.value = "http://example.com/subject2"
|
||||
triple2.p.value = "http://example.com/predicate2"
|
||||
triple2.o.value = "literal value"
|
||||
triple2.o.is_uri = False
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should have processed both triples
|
||||
# Triple1: 2 nodes + 1 relationship = 3 calls
|
||||
# Triple2: 1 node + 1 literal + 1 relationship = 3 calls
|
||||
# Total: 6 calls
|
||||
assert mock_driver.execute_query.call_count == 6
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_empty_triples(self, mock_graph_db):
|
||||
"""Test handling message with no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message with empty triples
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should not have made any execute_query calls beyond index creation
|
||||
# Only index creation calls should have been made during initialization
|
||||
mock_driver.execute_query.assert_not_called()
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'bolt://neo4j:7687'
|
||||
assert hasattr(args, 'username')
|
||||
assert args.username == 'neo4j'
|
||||
assert hasattr(args, 'password')
|
||||
assert args.password == 'password'
|
||||
assert hasattr(args, 'database')
|
||||
assert args.database == 'neo4j'
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph_host', 'bolt://custom:7687',
|
||||
'--username', 'testuser',
|
||||
'--password', 'testpass',
|
||||
'--database', 'testdb'
|
||||
])
|
||||
|
||||
assert args.graph_host == 'bolt://custom:7687'
|
||||
assert args.username == 'testuser'
|
||||
assert args.password == 'testpass'
|
||||
assert args.database == 'testdb'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'bolt://short:7687'])
|
||||
|
||||
assert args.graph_host == 'bolt://short:7687'
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.triples.neo4j.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nGraph writer. Input is graph edge. Writes edges to Neo4j graph.\n"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_special_characters(self, mock_graph_db):
|
||||
"""Test handling triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create triple with special characters
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject with spaces"
|
||||
triple.p.value = "http://example.com/predicate:with/symbols"
|
||||
triple.o.value = 'literal with "quotes" and unicode: ñáéíóú'
|
||||
triple.o.is_uri = False
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify the triple was processed with special characters preserved
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Node {uri: $uri})",
|
||||
uri="http://example.com/subject with spaces",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Literal {value: $value})",
|
||||
value='literal with "quotes" and unicode: ñáéíóú',
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MATCH (src:Node {uri: $src}) "
|
||||
"MATCH (dest:Literal {value: $dest}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
|
||||
src="http://example.com/subject with spaces",
|
||||
dest='literal with "quotes" and unicode: ñáéíóú',
|
||||
uri="http://example.com/predicate:with/symbols",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
|
@ -3,6 +3,9 @@ from cassandra.cluster import Cluster
|
|||
from cassandra.auth import PlainTextAuthProvider
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
||||
|
||||
# Global list to track clusters for cleanup
|
||||
_active_clusters = []
|
||||
|
||||
class TrustGraph:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -24,6 +27,9 @@ class TrustGraph:
|
|||
else:
|
||||
self.cluster = Cluster(hosts)
|
||||
self.session = self.cluster.connect()
|
||||
|
||||
# Track this cluster globally
|
||||
_active_clusters.append(self.cluster)
|
||||
|
||||
self.init()
|
||||
|
||||
|
|
@ -119,3 +125,13 @@ class TrustGraph:
|
|||
f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""",
|
||||
(s, p, o)
|
||||
)
|
||||
|
||||
def close(self):
|
||||
"""Close the Cassandra session and cluster connections properly"""
|
||||
if hasattr(self, 'session') and self.session:
|
||||
self.session.shutdown()
|
||||
if hasattr(self, 'cluster') and self.cluster:
|
||||
self.cluster.shutdown()
|
||||
# Remove from global tracking
|
||||
if self.cluster in _active_clusters:
|
||||
_active_clusters.remove(self.cluster)
|
||||
|
|
|
|||
|
|
@ -5,94 +5,56 @@ of chunks
|
|||
"""
|
||||
|
||||
from .... direct.milvus_doc_embeddings import DocVectors
|
||||
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from .... schema import DocumentEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import document_embeddings_request_queue
|
||||
from .... schema import document_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... base import DocumentEmbeddingsQueryService
|
||||
|
||||
module = "de-query"
|
||||
|
||||
default_input_queue = document_embeddings_request_queue
|
||||
default_output_queue = document_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_ident = "de-query"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
class Processor(DocumentEmbeddingsQueryService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": DocumentEmbeddingsRequest,
|
||||
"output_schema": DocumentEmbeddingsResponse,
|
||||
"store_uri": store_uri,
|
||||
}
|
||||
)
|
||||
|
||||
self.vecstore = DocVectors(store_uri)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def query_document_embeddings(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
# Handle zero limit case
|
||||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
|
||||
for vec in v.vectors:
|
||||
for vec in msg.vectors:
|
||||
|
||||
resp = self.vecstore.search(vec, limit=v.limit)
|
||||
resp = self.vecstore.search(vec, limit=msg.limit)
|
||||
|
||||
for r in resp:
|
||||
chunk = r["entity"]["doc"]
|
||||
chunk = chunk.encode("utf-8")
|
||||
chunks.append(chunk)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
return chunks
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = DocumentEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
documents=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
DocumentEmbeddingsQueryService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
|
|
@ -102,5 +64,5 @@ class Processor(ConsumerProducer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,30 +10,21 @@ from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
|||
import uuid
|
||||
import os
|
||||
|
||||
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import document_embeddings_request_queue
|
||||
from .... schema import document_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... base import DocumentEmbeddingsQueryService
|
||||
|
||||
module = "de-query"
|
||||
|
||||
default_input_queue = document_embeddings_request_queue
|
||||
default_output_queue = document_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_ident = "de-query"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
class Processor(DocumentEmbeddingsQueryService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
self.url = params.get("url", None)
|
||||
self.api_key = params.get("api_key", default_api_key)
|
||||
|
||||
if self.api_key is None or self.api_key == "not-specified":
|
||||
raise RuntimeError("Pinecone API key must be specified")
|
||||
|
||||
if self.url:
|
||||
|
||||
self.pinecone = PineconeGRPC(
|
||||
|
|
@ -47,88 +38,53 @@ class Processor(ConsumerProducer):
|
|||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": DocumentEmbeddingsRequest,
|
||||
"output_schema": DocumentEmbeddingsResponse,
|
||||
"url": self.url,
|
||||
"api_key": self.api_key,
|
||||
}
|
||||
)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def query_document_embeddings(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
# Handle zero limit case
|
||||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
|
||||
for vec in v.vectors:
|
||||
for vec in msg.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
index_name = (
|
||||
"d-" + v.user + "-" + str(dim)
|
||||
"d-" + msg.user + "-" + msg.collection + "-" + str(dim)
|
||||
)
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
results = index.query(
|
||||
namespace=v.collection,
|
||||
vector=vec,
|
||||
top_k=v.limit,
|
||||
top_k=msg.limit,
|
||||
include_values=False,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=v.limit,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
for r in results.matches:
|
||||
doc = r.metadata["doc"]
|
||||
chunks.add(doc)
|
||||
chunks.append(doc)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
return chunks
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = DocumentEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
documents=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
DocumentEmbeddingsQueryService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-a', '--api-key',
|
||||
|
|
@ -143,5 +99,5 @@ class Processor(ConsumerProducer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,35 +5,21 @@ entities
|
|||
"""
|
||||
|
||||
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import graph_embeddings_request_queue
|
||||
from .... schema import graph_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
|
||||
module = "ge-query"
|
||||
|
||||
default_input_queue = graph_embeddings_request_queue
|
||||
default_output_queue = graph_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_ident = "ge-query"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
class Processor(GraphEmbeddingsQueryService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphEmbeddingsRequest,
|
||||
"output_schema": GraphEmbeddingsResponse,
|
||||
"store_uri": store_uri,
|
||||
}
|
||||
)
|
||||
|
|
@ -46,29 +32,34 @@ class Processor(ConsumerProducer):
|
|||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def query_graph_embeddings(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
# Handle zero limit case
|
||||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
for vec in msg.vectors:
|
||||
|
||||
entities = set()
|
||||
|
||||
for vec in v.vectors:
|
||||
|
||||
resp = self.vecstore.search(vec, limit=v.limit)
|
||||
resp = self.vecstore.search(vec, limit=msg.limit * 2)
|
||||
|
||||
for r in resp:
|
||||
ent = r["entity"]["entity"]
|
||||
entities.add(ent)
|
||||
|
||||
# De-dupe entities
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(ent)
|
||||
|
||||
# Convert set to list
|
||||
entities = list(entities)
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
|
||||
ents2 = []
|
||||
|
||||
|
|
@ -78,36 +69,19 @@ class Processor(ConsumerProducer):
|
|||
entities = ents2
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = GraphEmbeddingsResponse(entities=entities, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
return entities
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = GraphEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
entities=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
GraphEmbeddingsQueryService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
|
|
@ -117,5 +91,5 @@ class Processor(ConsumerProducer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,30 +10,23 @@ from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
|||
import uuid
|
||||
import os
|
||||
|
||||
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import graph_embeddings_request_queue
|
||||
from .... schema import graph_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
|
||||
module = "ge-query"
|
||||
|
||||
default_input_queue = graph_embeddings_request_queue
|
||||
default_output_queue = graph_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_ident = "ge-query"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
class Processor(GraphEmbeddingsQueryService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
self.url = params.get("url", None)
|
||||
self.api_key = params.get("api_key", default_api_key)
|
||||
|
||||
if self.api_key is None or self.api_key == "not-specified":
|
||||
raise RuntimeError("Pinecone API key must be specified")
|
||||
|
||||
if self.url:
|
||||
|
||||
self.pinecone = PineconeGRPC(
|
||||
|
|
@ -47,12 +40,8 @@ class Processor(ConsumerProducer):
|
|||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphEmbeddingsRequest,
|
||||
"output_schema": GraphEmbeddingsResponse,
|
||||
"url": self.url,
|
||||
"api_key": self.api_key,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -62,26 +51,23 @@ class Processor(ConsumerProducer):
|
|||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def query_graph_embeddings(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
# Handle zero limit case
|
||||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for vec in v.vectors:
|
||||
for vec in msg.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
index_name = (
|
||||
"t-" + v.user + "-" + str(dim)
|
||||
"t-" + msg.user + "-" + msg.collection + "-" + str(dim)
|
||||
)
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
|
@ -89,9 +75,8 @@ class Processor(ConsumerProducer):
|
|||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
# of getting (limit) entities
|
||||
results = index.query(
|
||||
namespace=v.collection,
|
||||
vector=vec,
|
||||
top_k=v.limit * 2,
|
||||
top_k=msg.limit * 2,
|
||||
include_values=False,
|
||||
include_metadata=True
|
||||
)
|
||||
|
|
@ -106,10 +91,10 @@ class Processor(ConsumerProducer):
|
|||
entities.append(ent)
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= v.limit: break
|
||||
if len(entity_set) >= msg.limit: break
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= v.limit: break
|
||||
if len(entity_set) >= msg.limit: break
|
||||
|
||||
ents2 = []
|
||||
|
||||
|
|
@ -118,37 +103,17 @@ class Processor(ConsumerProducer):
|
|||
|
||||
entities = ents2
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = GraphEmbeddingsResponse(entities=entities, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
return entities
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = GraphEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
entities=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
GraphEmbeddingsQueryService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-a', '--api-key',
|
||||
|
|
@ -163,5 +128,5 @@ class Processor(ConsumerProducer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,37 +9,24 @@ from falkordb import FalkorDB
|
|||
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Value, Triple
|
||||
from .... schema import triples_request_queue
|
||||
from .... schema import triples_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... base import TriplesQueryService
|
||||
|
||||
module = "triples-query"
|
||||
|
||||
default_input_queue = triples_request_queue
|
||||
default_output_queue = triples_response_queue
|
||||
default_subscriber = module
|
||||
default_ident = "triples-query"
|
||||
|
||||
default_graph_url = 'falkor://falkordb:6379'
|
||||
default_database = 'falkordb'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
class Processor(TriplesQueryService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_url = params.get("graph_host", default_graph_url)
|
||||
graph_url = params.get("graph_url", default_graph_url)
|
||||
database = params.get("database", default_database)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TriplesQueryRequest,
|
||||
"output_schema": TriplesQueryResponse,
|
||||
"graph_url": graph_url,
|
||||
"database": database,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -54,50 +41,45 @@ class Processor(ConsumerProducer):
|
|||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def query_triples(self, query):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
triples = []
|
||||
|
||||
if v.s is not None:
|
||||
if v.p is not None:
|
||||
if v.o is not None:
|
||||
if query.s is not None:
|
||||
if query.p is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# SPO
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
|
||||
"RETURN $src as src",
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": v.s.value,
|
||||
"rel": v.p.value,
|
||||
"value": v.o.value,
|
||||
"src": query.s.value,
|
||||
"rel": query.p.value,
|
||||
"value": query.o.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, v.p.value, v.o.value))
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
|
||||
"RETURN $src as src",
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": v.s.value,
|
||||
"rel": v.p.value,
|
||||
"uri": v.o.value,
|
||||
"src": query.s.value,
|
||||
"rel": query.p.value,
|
||||
"uri": query.o.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, v.p.value, v.o.value))
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -105,116 +87,124 @@ class Processor(ConsumerProducer):
|
|||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
|
||||
"RETURN dest.value as dest",
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": v.s.value,
|
||||
"rel": v.p.value,
|
||||
"src": query.s.value,
|
||||
"rel": query.p.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, v.p.value, rec[0]))
|
||||
triples.append((query.s.value, query.p.value, rec[0]))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
|
||||
"RETURN dest.uri as dest",
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": v.s.value,
|
||||
"rel": v.p.value,
|
||||
"src": query.s.value,
|
||||
"rel": query.p.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, v.p.value, rec[0]))
|
||||
triples.append((query.s.value, query.p.value, rec[0]))
|
||||
|
||||
else:
|
||||
|
||||
if v.o is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# SO
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"RETURN rel.uri as rel",
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": v.s.value,
|
||||
"value": v.o.value,
|
||||
"src": query.s.value,
|
||||
"value": query.o.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, rec[0], v.o.value))
|
||||
triples.append((query.s.value, rec[0], query.o.value))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"RETURN rel.uri as rel",
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": v.s.value,
|
||||
"uri": v.o.value,
|
||||
"src": query.s.value,
|
||||
"uri": query.o.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, rec[0], v.o.value))
|
||||
triples.append((query.s.value, rec[0], query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
# s
|
||||
|
||||
records = self.io.query(
|
||||
"match (src:node {uri: $src})-[rel:rel]->(dest:literal) "
|
||||
"return rel.uri as rel, dest.value as dest",
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": v.s.value,
|
||||
"src": query.s.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, rec[0], rec[1]))
|
||||
triples.append((query.s.value, rec[0], rec[1]))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
|
||||
"RETURN rel.uri as rel, dest.uri as dest",
|
||||
"RETURN rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"src": v.s.value,
|
||||
"src": query.s.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, rec[0], rec[1]))
|
||||
triples.append((query.s.value, rec[0], rec[1]))
|
||||
|
||||
|
||||
else:
|
||||
|
||||
if v.p is not None:
|
||||
if query.p is not None:
|
||||
|
||||
if v.o is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# PO
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
|
||||
"RETURN src.uri as src",
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": v.p.value,
|
||||
"value": v.o.value,
|
||||
"uri": query.p.value,
|
||||
"value": query.o.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], v.p.value, v.o.value))
|
||||
triples.append((rec[0], query.p.value, query.o.value))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) "
|
||||
"RETURN src.uri as src",
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": v.p.value,
|
||||
"dest": v.o.value,
|
||||
"uri": query.p.value,
|
||||
"dest": query.o.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], v.p.value, v.o.value))
|
||||
triples.append((rec[0], query.p.value, query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -222,53 +212,57 @@ class Processor(ConsumerProducer):
|
|||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
|
||||
"RETURN src.uri as src, dest.value as dest",
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": v.p.value,
|
||||
"uri": query.p.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], v.p.value, rec[1]))
|
||||
triples.append((rec[0], query.p.value, rec[1]))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
|
||||
"RETURN src.uri as src, dest.uri as dest",
|
||||
"RETURN src.uri as src, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": v.p.value,
|
||||
"uri": query.p.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], v.p.value, rec[1]))
|
||||
triples.append((rec[0], query.p.value, rec[1]))
|
||||
|
||||
else:
|
||||
|
||||
if v.o is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# O
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"RETURN src.uri as src, rel.uri as rel",
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"value": v.o.value,
|
||||
"value": query.o.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], rec[1], v.o.value))
|
||||
triples.append((rec[0], rec[1], query.o.value))
|
||||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"RETURN src.uri as src, rel.uri as rel",
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(query.limit),
|
||||
params={
|
||||
"uri": v.o.value,
|
||||
"uri": query.o.value,
|
||||
},
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
triples.append((rec[0], rec[1], v.o.value))
|
||||
triples.append((rec[0], rec[1], query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -276,7 +270,8 @@ class Processor(ConsumerProducer):
|
|||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest",
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
|
|
@ -284,7 +279,8 @@ class Processor(ConsumerProducer):
|
|||
|
||||
records = self.io.query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest",
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(query.limit),
|
||||
).result_set
|
||||
|
||||
for rec in records:
|
||||
|
|
@ -296,40 +292,20 @@ class Processor(ConsumerProducer):
|
|||
p=self.create_value(t[1]),
|
||||
o=self.create_value(t[2])
|
||||
)
|
||||
for t in triples
|
||||
for t in triples[:query.limit]
|
||||
]
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
return triples
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TriplesQueryResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
TriplesQueryService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-url',
|
||||
|
|
@ -345,5 +321,5 @@ class Processor(ConsumerProducer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,28 +9,19 @@ from neo4j import GraphDatabase
|
|||
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Value, Triple
|
||||
from .... schema import triples_request_queue
|
||||
from .... schema import triples_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... base import TriplesQueryService
|
||||
|
||||
module = "triples-query"
|
||||
|
||||
default_input_queue = triples_request_queue
|
||||
default_output_queue = triples_response_queue
|
||||
default_subscriber = module
|
||||
default_ident = "triples-query"
|
||||
|
||||
default_graph_host = 'bolt://memgraph:7687'
|
||||
default_username = 'memgraph'
|
||||
default_password = 'password'
|
||||
default_database = 'memgraph'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
class Processor(TriplesQueryService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
username = params.get("username", default_username)
|
||||
password = params.get("password", default_password)
|
||||
|
|
@ -38,12 +29,9 @@ class Processor(ConsumerProducer):
|
|||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TriplesQueryRequest,
|
||||
"output_schema": TriplesQueryResponse,
|
||||
"graph_host": graph_host,
|
||||
"username": username,
|
||||
"database": database,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -58,46 +46,39 @@ class Processor(ConsumerProducer):
|
|||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def query_triples(self, query):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
triples = []
|
||||
|
||||
if v.s is not None:
|
||||
if v.p is not None:
|
||||
if v.o is not None:
|
||||
if query.s is not None:
|
||||
if query.p is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# SPO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(v.limit),
|
||||
src=v.s.value, rel=v.p.value, value=v.o.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value, value=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, v.p.value, v.o.value))
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT " + str(v.limit),
|
||||
src=v.s.value, rel=v.p.value, uri=v.o.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value, uri=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, v.p.value, v.o.value))
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -106,56 +87,56 @@ class Processor(ConsumerProducer):
|
|||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT " + str(v.limit),
|
||||
src=v.s.value, rel=v.p.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, v.p.value, data["dest"]))
|
||||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT " + str(v.limit),
|
||||
src=v.s.value, rel=v.p.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, rel=query.p.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, v.p.value, data["dest"]))
|
||||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
|
||||
else:
|
||||
|
||||
if v.o is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# SO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(v.limit),
|
||||
src=v.s.value, value=v.o.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, value=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, data["rel"], v.o.value))
|
||||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT " + str(v.limit),
|
||||
src=v.s.value, uri=v.o.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value, uri=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, data["rel"], v.o.value))
|
||||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -164,59 +145,59 @@ class Processor(ConsumerProducer):
|
|||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(v.limit),
|
||||
src=v.s.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, data["rel"], data["dest"]))
|
||||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
|
||||
"RETURN rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(v.limit),
|
||||
src=v.s.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
src=query.s.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, data["rel"], data["dest"]))
|
||||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
|
||||
|
||||
else:
|
||||
|
||||
if v.p is not None:
|
||||
if query.p is not None:
|
||||
|
||||
if v.o is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# PO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(v.limit),
|
||||
uri=v.p.value, value=v.o.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value, value=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], v.p.value, v.o.value))
|
||||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) "
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT " + str(v.limit),
|
||||
uri=v.p.value, dest=v.o.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value, dest=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], v.p.value, v.o.value))
|
||||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -225,56 +206,56 @@ class Processor(ConsumerProducer):
|
|||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT " + str(v.limit),
|
||||
uri=v.p.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], v.p.value, data["dest"]))
|
||||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
|
||||
"RETURN src.uri as src, dest.uri as dest "
|
||||
"LIMIT " + str(v.limit),
|
||||
uri=v.p.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.p.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], v.p.value, data["dest"]))
|
||||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
|
||||
else:
|
||||
|
||||
if v.o is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# O
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(v.limit),
|
||||
value=v.o.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
value=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], data["rel"], v.o.value))
|
||||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT " + str(v.limit),
|
||||
uri=v.o.value,
|
||||
"LIMIT " + str(query.limit),
|
||||
uri=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], data["rel"], v.o.value))
|
||||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -283,7 +264,7 @@ class Processor(ConsumerProducer):
|
|||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT " + str(v.limit),
|
||||
"LIMIT " + str(query.limit),
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -294,7 +275,7 @@ class Processor(ConsumerProducer):
|
|||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT " + str(v.limit),
|
||||
"LIMIT " + str(query.limit),
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
|
|
@ -308,40 +289,22 @@ class Processor(ConsumerProducer):
|
|||
p=self.create_value(t[1]),
|
||||
o=self.create_value(t[2])
|
||||
)
|
||||
for t in triples[:v.limit]
|
||||
for t in triples[:query.limit]
|
||||
]
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
return triples
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TriplesQueryResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
print(f"Exception: {e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
TriplesQueryService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-host',
|
||||
|
|
@ -369,5 +332,5 @@ class Processor(ConsumerProducer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,28 +9,19 @@ from neo4j import GraphDatabase
|
|||
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Value, Triple
|
||||
from .... schema import triples_request_queue
|
||||
from .... schema import triples_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... base import TriplesQueryService
|
||||
|
||||
module = "triples-query"
|
||||
|
||||
default_input_queue = triples_request_queue
|
||||
default_output_queue = triples_response_queue
|
||||
default_subscriber = module
|
||||
default_ident = "triples-query"
|
||||
|
||||
default_graph_host = 'bolt://neo4j:7687'
|
||||
default_username = 'neo4j'
|
||||
default_password = 'password'
|
||||
default_database = 'neo4j'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
class Processor(TriplesQueryService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
username = params.get("username", default_username)
|
||||
password = params.get("password", default_password)
|
||||
|
|
@ -38,12 +29,9 @@ class Processor(ConsumerProducer):
|
|||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TriplesQueryRequest,
|
||||
"output_schema": TriplesQueryResponse,
|
||||
"graph_host": graph_host,
|
||||
"username": username,
|
||||
"database": database,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -58,44 +46,37 @@ class Processor(ConsumerProducer):
|
|||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def query_triples(self, query):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
triples = []
|
||||
|
||||
if v.s is not None:
|
||||
if v.p is not None:
|
||||
if v.o is not None:
|
||||
if query.s is not None:
|
||||
if query.p is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# SPO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
|
||||
"RETURN $src as src",
|
||||
src=v.s.value, rel=v.p.value, value=v.o.value,
|
||||
src=query.s.value, rel=query.p.value, value=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, v.p.value, v.o.value))
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
|
||||
"RETURN $src as src",
|
||||
src=v.s.value, rel=v.p.value, uri=v.o.value,
|
||||
src=query.s.value, rel=query.p.value, uri=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
triples.append((v.s.value, v.p.value, v.o.value))
|
||||
triples.append((query.s.value, query.p.value, query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -104,52 +85,52 @@ class Processor(ConsumerProducer):
|
|||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
|
||||
"RETURN dest.value as dest",
|
||||
src=v.s.value, rel=v.p.value,
|
||||
src=query.s.value, rel=query.p.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, v.p.value, data["dest"]))
|
||||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
|
||||
"RETURN dest.uri as dest",
|
||||
src=v.s.value, rel=v.p.value,
|
||||
src=query.s.value, rel=query.p.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, v.p.value, data["dest"]))
|
||||
triples.append((query.s.value, query.p.value, data["dest"]))
|
||||
|
||||
else:
|
||||
|
||||
if v.o is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# SO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"RETURN rel.uri as rel",
|
||||
src=v.s.value, value=v.o.value,
|
||||
src=query.s.value, value=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, data["rel"], v.o.value))
|
||||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"RETURN rel.uri as rel",
|
||||
src=v.s.value, uri=v.o.value,
|
||||
src=query.s.value, uri=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, data["rel"], v.o.value))
|
||||
triples.append((query.s.value, data["rel"], query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -158,55 +139,55 @@ class Processor(ConsumerProducer):
|
|||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
|
||||
"RETURN rel.uri as rel, dest.value as dest",
|
||||
src=v.s.value,
|
||||
src=query.s.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, data["rel"], data["dest"]))
|
||||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
|
||||
"RETURN rel.uri as rel, dest.uri as dest",
|
||||
src=v.s.value,
|
||||
src=query.s.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((v.s.value, data["rel"], data["dest"]))
|
||||
triples.append((query.s.value, data["rel"], data["dest"]))
|
||||
|
||||
|
||||
else:
|
||||
|
||||
if v.p is not None:
|
||||
if query.p is not None:
|
||||
|
||||
if v.o is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# PO
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
|
||||
"RETURN src.uri as src",
|
||||
uri=v.p.value, value=v.o.value,
|
||||
uri=query.p.value, value=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], v.p.value, v.o.value))
|
||||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $uri}) "
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
|
||||
"RETURN src.uri as src",
|
||||
uri=v.p.value, dest=v.o.value,
|
||||
uri=query.p.value, dest=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], v.p.value, v.o.value))
|
||||
triples.append((data["src"], query.p.value, query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -215,52 +196,52 @@ class Processor(ConsumerProducer):
|
|||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
|
||||
"RETURN src.uri as src, dest.value as dest",
|
||||
uri=v.p.value,
|
||||
uri=query.p.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], v.p.value, data["dest"]))
|
||||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
|
||||
"RETURN src.uri as src, dest.uri as dest",
|
||||
uri=v.p.value,
|
||||
uri=query.p.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], v.p.value, data["dest"]))
|
||||
triples.append((data["src"], query.p.value, data["dest"]))
|
||||
|
||||
else:
|
||||
|
||||
if v.o is not None:
|
||||
if query.o is not None:
|
||||
|
||||
# O
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
|
||||
"RETURN src.uri as src, rel.uri as rel",
|
||||
value=v.o.value,
|
||||
value=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], data["rel"], v.o.value))
|
||||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
|
||||
records, summary, keys = self.io.execute_query(
|
||||
"MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
|
||||
"RETURN src.uri as src, rel.uri as rel",
|
||||
uri=v.o.value,
|
||||
uri=query.o.value,
|
||||
database_=self.db,
|
||||
)
|
||||
|
||||
for rec in records:
|
||||
data = rec.data()
|
||||
triples.append((data["src"], data["rel"], v.o.value))
|
||||
triples.append((data["src"], data["rel"], query.o.value))
|
||||
|
||||
else:
|
||||
|
||||
|
|
@ -295,37 +276,17 @@ class Processor(ConsumerProducer):
|
|||
for t in triples
|
||||
]
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
return triples
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TriplesQueryResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
await self.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
TriplesQueryService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-host',
|
||||
|
|
@ -353,5 +314,5 @@ class Processor(ConsumerProducer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,58 +4,41 @@ Accepts entity/vector pairs and writes them to a Milvus store.
|
|||
"""
|
||||
|
||||
from .... direct.milvus_doc_embeddings import DocVectors
|
||||
from .... base import DocumentEmbeddingsStoreService
|
||||
|
||||
from .... schema import DocumentEmbeddings
|
||||
from .... schema import document_embeddings_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = "de-write"
|
||||
|
||||
default_input_queue = document_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
default_ident = "de-write"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(Consumer):
|
||||
class Processor(DocumentEmbeddingsStoreService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": DocumentEmbeddings,
|
||||
"store_uri": store_uri,
|
||||
}
|
||||
)
|
||||
|
||||
self.vecstore = DocVectors(store_uri)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
for emb in v.chunks:
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "" or chunk is None: continue
|
||||
if chunk == "": continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
|
||||
if chunk != "" and v.chunk is not None:
|
||||
for vec in v.vectors:
|
||||
self.vecstore.insert(vec, chunk)
|
||||
self.vecstore.insert(vec, chunk)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
DocumentEmbeddingsStoreService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
|
|
@ -65,5 +48,5 @@ class Processor(Consumer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,42 +1,32 @@
|
|||
|
||||
"""
|
||||
Accepts entity/vector pairs and writes them to a Qdrant store.
|
||||
Accepts document chunks/vector pairs and writes them to a Pinecone store.
|
||||
"""
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from .... schema import DocumentEmbeddings
|
||||
from .... schema import document_embeddings_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
from .... base import DocumentEmbeddingsStoreService
|
||||
|
||||
module = "de-write"
|
||||
|
||||
default_input_queue = document_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
default_ident = "de-write"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
default_cloud = "aws"
|
||||
default_region = "us-east-1"
|
||||
|
||||
class Processor(Consumer):
|
||||
class Processor(DocumentEmbeddingsStoreService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
self.url = params.get("url", None)
|
||||
self.cloud = params.get("cloud", default_cloud)
|
||||
self.region = params.get("region", default_region)
|
||||
self.api_key = params.get("api_key", default_api_key)
|
||||
|
||||
if self.api_key is None:
|
||||
if self.api_key is None or self.api_key == "not-specified":
|
||||
raise RuntimeError("Pinecone API key must be specified")
|
||||
|
||||
if self.url:
|
||||
|
|
@ -52,94 +42,96 @@ class Processor(Consumer):
|
|||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": DocumentEmbeddings,
|
||||
"url": self.url,
|
||||
"cloud": self.cloud,
|
||||
"region": self.region,
|
||||
"api_key": self.api_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.last_index_name = None
|
||||
|
||||
async def handle(self, msg):
|
||||
def create_index(self, index_name, dim):
|
||||
|
||||
v = msg.value()
|
||||
self.pinecone.create_index(
|
||||
name = index_name,
|
||||
dimension = dim,
|
||||
metric = "cosine",
|
||||
spec = ServerlessSpec(
|
||||
cloud = self.cloud,
|
||||
region = self.region,
|
||||
)
|
||||
)
|
||||
|
||||
for emb in v.chunks:
|
||||
for i in range(0, 1000):
|
||||
|
||||
if self.pinecone.describe_index(
|
||||
index_name
|
||||
).status["ready"]:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if not self.pinecone.describe_index(
|
||||
index_name
|
||||
).status["ready"]:
|
||||
raise RuntimeError(
|
||||
"Gave up waiting for index creation"
|
||||
)
|
||||
|
||||
async def store_document_embeddings(self, message):
|
||||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "" or chunk is None: continue
|
||||
if chunk == "": continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
|
||||
for vec in v.vectors:
|
||||
dim = len(vec)
|
||||
index_name = (
|
||||
"d-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
|
||||
)
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"d-" + v.metadata.user + "-" + str(dim)
|
||||
)
|
||||
if index_name != self.last_index_name:
|
||||
|
||||
if index_name != self.last_index_name:
|
||||
if not self.pinecone.has_index(index_name):
|
||||
|
||||
if not self.pinecone.has_index(index_name):
|
||||
try:
|
||||
|
||||
try:
|
||||
self.create_index(index_name, dim)
|
||||
|
||||
self.pinecone.create_index(
|
||||
name = index_name,
|
||||
dimension = dim,
|
||||
metric = "cosine",
|
||||
spec = ServerlessSpec(
|
||||
cloud = self.cloud,
|
||||
region = self.region,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print("Pinecone index creation failed")
|
||||
raise e
|
||||
|
||||
for i in range(0, 1000):
|
||||
print(f"Index {index_name} created", flush=True)
|
||||
|
||||
if self.pinecone.describe_index(
|
||||
index_name
|
||||
).status["ready"]:
|
||||
break
|
||||
self.last_index_name = index_name
|
||||
|
||||
time.sleep(1)
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
if not self.pinecone.describe_index(
|
||||
index_name
|
||||
).status["ready"]:
|
||||
raise RuntimeError(
|
||||
"Gave up waiting for index creation"
|
||||
)
|
||||
# Generate unique ID for each vector
|
||||
vector_id = str(uuid.uuid4())
|
||||
|
||||
except Exception as e:
|
||||
print("Pinecone index creation failed")
|
||||
raise e
|
||||
records = [
|
||||
{
|
||||
"id": vector_id,
|
||||
"values": vec,
|
||||
"metadata": { "doc": chunk },
|
||||
}
|
||||
]
|
||||
|
||||
print(f"Index {index_name} created", flush=True)
|
||||
|
||||
self.last_index_name = index_name
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
records = [
|
||||
{
|
||||
"id": id,
|
||||
"values": vec,
|
||||
"metadata": { "doc": chunk },
|
||||
}
|
||||
]
|
||||
|
||||
index.upsert(
|
||||
vectors = records,
|
||||
namespace = v.metadata.collection,
|
||||
)
|
||||
index.upsert(
|
||||
vectors = records,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
DocumentEmbeddingsStoreService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-a', '--api-key',
|
||||
|
|
@ -166,5 +158,5 @@ class Processor(Consumer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,42 +3,29 @@
|
|||
Accepts entity/vector pairs and writes them to a Milvus store.
|
||||
"""
|
||||
|
||||
from .... schema import GraphEmbeddings
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||
from .... base import Consumer
|
||||
from .... base import GraphEmbeddingsStoreService
|
||||
|
||||
module = "ge-write"
|
||||
|
||||
default_input_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
default_ident = "ge-write"
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(Consumer):
|
||||
class Processor(GraphEmbeddingsStoreService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphEmbeddings,
|
||||
"store_uri": store_uri,
|
||||
}
|
||||
)
|
||||
|
||||
self.vecstore = EntityVectors(store_uri)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def store_graph_embeddings(self, message):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
for entity in v.entities:
|
||||
for entity in message.entities:
|
||||
|
||||
if entity.entity.value != "" and entity.entity.value is not None:
|
||||
for vec in entity.vectors:
|
||||
|
|
@ -47,9 +34,7 @@ class Processor(Consumer):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
GraphEmbeddingsStoreService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
|
|
@ -59,5 +44,5 @@ class Processor(Consumer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,32 +10,23 @@ import time
|
|||
import uuid
|
||||
import os
|
||||
|
||||
from .... schema import GraphEmbeddings
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
from .... base import GraphEmbeddingsStoreService
|
||||
|
||||
module = "ge-write"
|
||||
|
||||
default_input_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
default_ident = "ge-write"
|
||||
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
|
||||
default_cloud = "aws"
|
||||
default_region = "us-east-1"
|
||||
|
||||
class Processor(Consumer):
|
||||
class Processor(GraphEmbeddingsStoreService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
self.url = params.get("url", None)
|
||||
self.cloud = params.get("cloud", default_cloud)
|
||||
self.region = params.get("region", default_region)
|
||||
self.api_key = params.get("api_key", default_api_key)
|
||||
|
||||
if self.api_key is None:
|
||||
if self.api_key is None or self.api_key == "not-specified":
|
||||
raise RuntimeError("Pinecone API key must be specified")
|
||||
|
||||
if self.url:
|
||||
|
|
@ -51,10 +42,10 @@ class Processor(Consumer):
|
|||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphEmbeddings,
|
||||
"url": self.url,
|
||||
"cloud": self.cloud,
|
||||
"region": self.region,
|
||||
"api_key": self.api_key,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -88,13 +79,9 @@ class Processor(Consumer):
|
|||
"Gave up waiting for index creation"
|
||||
)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def store_graph_embeddings(self, message):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
for entity in v.entities:
|
||||
for entity in message.entities:
|
||||
|
||||
if entity.entity.value == "" or entity.entity.value is None:
|
||||
continue
|
||||
|
|
@ -104,7 +91,7 @@ class Processor(Consumer):
|
|||
dim = len(vec)
|
||||
|
||||
index_name = (
|
||||
"t-" + v.metadata.user + "-" + str(dim)
|
||||
"t-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
|
||||
)
|
||||
|
||||
if index_name != self.last_index_name:
|
||||
|
|
@ -125,9 +112,12 @@ class Processor(Consumer):
|
|||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Generate unique ID for each vector
|
||||
vector_id = str(uuid.uuid4())
|
||||
|
||||
records = [
|
||||
{
|
||||
"id": id,
|
||||
"id": vector_id,
|
||||
"values": vec,
|
||||
"metadata": { "entity": entity.entity.value },
|
||||
}
|
||||
|
|
@ -135,15 +125,12 @@ class Processor(Consumer):
|
|||
|
||||
index.upsert(
|
||||
vectors = records,
|
||||
namespace = v.metadata.collection,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
GraphEmbeddingsStoreService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-a', '--api-key',
|
||||
|
|
@ -170,5 +157,5 @@ class Processor(Consumer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,34 +11,24 @@ import time
|
|||
|
||||
from falkordb import FalkorDB
|
||||
|
||||
from .... schema import Triples
|
||||
from .... schema import triples_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
from .... base import TriplesStoreService
|
||||
|
||||
module = "triples-write"
|
||||
|
||||
default_input_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
default_ident = "triples-write"
|
||||
|
||||
default_graph_url = 'falkor://falkordb:6379'
|
||||
default_database = 'falkordb'
|
||||
|
||||
class Processor(Consumer):
|
||||
class Processor(TriplesStoreService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_url = params.get("graph_host", default_graph_url)
|
||||
graph_url = params.get("graph_url", default_graph_url)
|
||||
database = params.get("database", default_database)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Triples,
|
||||
"graph_url": graph_url,
|
||||
"database": database,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -118,11 +108,9 @@ class Processor(Consumer):
|
|||
time=res.run_time_ms
|
||||
))
|
||||
|
||||
async def handle(self, msg):
|
||||
async def store_triples(self, message):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
for t in v.triples:
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value)
|
||||
|
||||
|
|
@ -136,14 +124,12 @@ class Processor(Consumer):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
TriplesStoreService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph_host',
|
||||
'-g', '--graph-url',
|
||||
default=default_graph_url,
|
||||
help=f'Graph host (default: {default_graph_url})'
|
||||
help=f'Graph URL (default: {default_graph_url})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -154,5 +140,5 @@ class Processor(Consumer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,27 +11,19 @@ import time
|
|||
|
||||
from neo4j import GraphDatabase
|
||||
|
||||
from .... schema import Triples
|
||||
from .... schema import triples_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
from .... base import TriplesStoreService
|
||||
|
||||
module = "triples-write"
|
||||
|
||||
default_input_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
default_ident = "triples-write"
|
||||
|
||||
default_graph_host = 'bolt://memgraph:7687'
|
||||
default_username = 'memgraph'
|
||||
default_password = 'password'
|
||||
default_database = 'memgraph'
|
||||
|
||||
class Processor(Consumer):
|
||||
class Processor(TriplesStoreService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
username = params.get("username", default_username)
|
||||
password = params.get("password", default_password)
|
||||
|
|
@ -39,10 +31,10 @@ class Processor(Consumer):
|
|||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Triples,
|
||||
"graph_host": graph_host,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"database": database,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -205,11 +197,9 @@ class Processor(Consumer):
|
|||
src=t.s.value, dest=t.o.value, uri=t.p.value,
|
||||
)
|
||||
|
||||
async def handle(self, msg):
|
||||
async def store_triples(self, message):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
for t in v.triples:
|
||||
for t in message.triples:
|
||||
|
||||
# self.create_node(t.s.value)
|
||||
|
||||
|
|
@ -226,12 +216,10 @@ class Processor(Consumer):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
TriplesStoreService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph_host',
|
||||
'-g', '--graph-host',
|
||||
default=default_graph_host,
|
||||
help=f'Graph host (default: {default_graph_host})'
|
||||
)
|
||||
|
|
@ -256,5 +244,5 @@ class Processor(Consumer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,28 +10,21 @@ import argparse
|
|||
import time
|
||||
|
||||
from neo4j import GraphDatabase
|
||||
from .... base import TriplesStoreService
|
||||
|
||||
from .... schema import Triples
|
||||
from .... schema import triples_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = "triples-write"
|
||||
|
||||
default_input_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
default_ident = "triples-write"
|
||||
|
||||
default_graph_host = 'bolt://neo4j:7687'
|
||||
default_username = 'neo4j'
|
||||
default_password = 'password'
|
||||
default_database = 'neo4j'
|
||||
|
||||
class Processor(Consumer):
|
||||
class Processor(TriplesStoreService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
username = params.get("username", default_username)
|
||||
password = params.get("password", default_password)
|
||||
|
|
@ -39,10 +32,9 @@ class Processor(Consumer):
|
|||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Triples,
|
||||
"graph_host": graph_host,
|
||||
"username": username,
|
||||
"database": database,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -158,11 +150,9 @@ class Processor(Consumer):
|
|||
time=summary.result_available_after
|
||||
))
|
||||
|
||||
async def handle(self, msg):
|
||||
async def store_triples(self, message):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
for t in v.triples:
|
||||
for t in message.triples:
|
||||
|
||||
self.create_node(t.s.value)
|
||||
|
||||
|
|
@ -176,9 +166,7 @@ class Processor(Consumer):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
TriplesStoreService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph_host',
|
||||
|
|
@ -206,5 +194,5 @@ class Processor(Consumer):
|
|||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue