Release/v1.2 (#457)

* Bump setup.py versions for 1.1

* PoC MCP server (#419)

* Very initial MCP server PoC for TrustGraph

* Put service on port 8000

* Add MCP container and packages to buildout

* Update docs for API/CLI changes in 1.0 (#421)

* Update some API basics for the 0.23/1.0 API change

* Add MCP container push (#425)

* Add command args to the MCP server (#426)

* Host and port parameters

* Added websocket arg

* More docs

* MCP client support (#427)

- MCP client service
- Tool request/response schema
- API gateway support for mcp-tool
- Message translation for tool request & response
- Make mcp-tool using configuration service for information
  about where the MCP services are.

* Feature/react call mcp (#428)

Key Features

  - MCP Tool Integration: Added core MCP tool support with ToolClientSpec and ToolClient classes
  - API Enhancement: New mcp_tool method for flow-specific tool invocation
  - CLI Tooling: New tg-invoke-mcp-tool command for testing MCP integration
  - React Agent Enhancement: Fixed and improved multi-tool invocation capabilities
  - Tool Management: Enhanced CLI for tool configuration and management

Changes

  - Added MCP tool invocation to API with flow-specific integration
  - Implemented ToolClientSpec and ToolClient for tool call handling
  - Updated agent-manager-react to invoke MCP tools with configurable types
  - Enhanced CLI with new commands and improved help text
  - Added comprehensive documentation for new CLI commands
  - Improved tool configuration management

Testing

  - Added tg-invoke-mcp-tool CLI command for isolated MCP integration testing
  - Enhanced agent capability to invoke multiple tools simultaneously

* Test suite executed from CI pipeline (#433)

* Test strategy & test cases

* Unit tests

* Integration tests

* Extending test coverage (#434)

* Contract tests

* Testing embeedings

* Agent unit tests

* Knowledge pipeline tests

* Turn on contract tests

* Increase storage test coverage (#435)

* Fixing storage and adding tests

* PR pipeline only runs quick tests

* Empty configuration is returned as empty list, previously was not in response (#436)

* Update config util to take files as well as command-line text (#437)

* Updated CLI invocation and config model for tools and mcp (#438)

* Updated CLI invocation and config model for tools and mcp

* CLI anomalies

* Tweaked the MCP tool implementation for new model

* Update agent implementation to match the new model

* Fix agent tools, now all tested

* Fixed integration tests

* Fix MCP delete tool params

* Update Python deps to 1.2

* Update to enable knowledge extraction using the agent framework (#439)

* Implement KG extraction agent (kg-extract-agent)

* Using ReAct framework (agent-manager-react)
 
* ReAct manager had an issue when emitting JSON, which conflicts which ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure.
 
* Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework.

* Migrate from setup.py to pyproject.toml (#440)

* Converted setup.py to pyproject.toml

* Modern package infrastructure as recommended by py docs

* Install missing build deps (#441)

* Install missing build deps (#442)

* Implement logging strategy (#444)

* Logging strategy and convert all prints() to logging invocations

* Fix/startup failure (#445)

* Fix loggin startup problems

* Fix logging startup problems (#446)

* Fix logging startup problems (#447)

* Fixed Mistral OCR to use current API (#448)

* Fixed Mistral OCR to use current API

* Added PDF decoder tests

* Fix Mistral OCR ident to be standard pdf-decoder (#450)

* Fix Mistral OCR ident to be standard pdf-decoder

* Correct test

* Schema structure refactor (#451)

* Write schema refactor spec

* Implemented schema refactor spec

* Structure data mvp (#452)

* Structured data tech spec

* Architecture principles

* New schemas

* Updated schemas and specs

* Object extractor

* Add .coveragerc

* New tests

* Cassandra object storage

* Trying to object extraction working, issues exist

* Validate librarian collection (#453)

* Fix token chunker, broken API invocation (#454)

* Fix token chunker, broken API invocation (#455)

* Knowledge load utility CLI (#456)

* Knowledge loader

* More tests
This commit is contained in:
cybermaggedon 2025-08-18 20:56:09 +01:00 committed by GitHub
parent c85ba197be
commit 89be656990
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
509 changed files with 49632 additions and 5159 deletions

View file

@ -0,0 +1,148 @@
"""
Shared fixtures for query tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
@pytest.fixture
def base_query_config():
"""Base configuration for query processors"""
return {
'taskgroup': AsyncMock(),
'id': 'test-query-processor'
}
@pytest.fixture
def qdrant_query_config(base_query_config):
"""Configuration for Qdrant query processors"""
return base_query_config | {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key'
}
@pytest.fixture
def mock_qdrant_client():
"""Mock Qdrant client"""
mock_client = MagicMock()
mock_client.query_points.return_value = []
return mock_client
# Graph embeddings query fixtures
@pytest.fixture
def mock_graph_embeddings_request():
"""Mock graph embeddings request message"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
return mock_message
@pytest.fixture
def mock_graph_embeddings_multiple_vectors():
"""Mock graph embeddings request with multiple vectors"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
return mock_message
@pytest.fixture
def mock_graph_embeddings_query_response():
"""Mock graph embeddings query response from Qdrant"""
mock_point1 = MagicMock()
mock_point1.payload = {'entity': 'entity1'}
mock_point2 = MagicMock()
mock_point2.payload = {'entity': 'entity2'}
return [mock_point1, mock_point2]
@pytest.fixture
def mock_graph_embeddings_uri_response():
"""Mock graph embeddings query response with URIs"""
mock_point1 = MagicMock()
mock_point1.payload = {'entity': 'http://example.com/entity1'}
mock_point2 = MagicMock()
mock_point2.payload = {'entity': 'https://secure.example.com/entity2'}
mock_point3 = MagicMock()
mock_point3.payload = {'entity': 'regular entity'}
return [mock_point1, mock_point2, mock_point3]
# Document embeddings query fixtures
@pytest.fixture
def mock_document_embeddings_request():
"""Mock document embeddings request message"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
return mock_message
@pytest.fixture
def mock_document_embeddings_multiple_vectors():
"""Mock document embeddings request with multiple vectors"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
return mock_message
@pytest.fixture
def mock_document_embeddings_query_response():
"""Mock document embeddings query response from Qdrant"""
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'first document chunk'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'second document chunk'}
return [mock_point1, mock_point2]
@pytest.fixture
def mock_document_embeddings_utf8_response():
"""Mock document embeddings query response with UTF-8 content"""
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
return [mock_point1, mock_point2]
@pytest.fixture
def mock_empty_query_response():
"""Mock empty query response"""
return []
@pytest.fixture
def mock_large_query_response():
"""Mock large query response with many results"""
mock_points = []
for i in range(10):
mock_point = MagicMock()
mock_point.payload = {'doc': f'document chunk {i}'}
mock_points.append(mock_point)
return mock_points
@pytest.fixture
def mock_mixed_dimension_vectors():
"""Mock request with vectors of different dimensions"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
return mock_message

View file

@ -0,0 +1,456 @@
"""
Tests for Milvus document embeddings query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.doc_embeddings.milvus.service import Processor
from trustgraph.schema import DocumentEmbeddingsRequest
class TestMilvusDocEmbeddingsQueryProcessor:
"""Test cases for Milvus document embeddings query processor"""
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors') as mock_doc_vectors:
mock_vecstore = MagicMock()
mock_doc_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=MagicMock(),
id='test-milvus-de-query',
store_uri='http://localhost:19530'
)
return processor
@pytest.fixture
def mock_query_request(self):
"""Create a mock query request for testing"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=10
)
return query
@patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors')
def test_processor_initialization_with_defaults(self, mock_doc_vectors):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_doc_vectors.return_value = mock_vecstore
processor = Processor(taskgroup=taskgroup_mock)
mock_doc_vectors.assert_called_once_with('http://localhost:19530')
assert processor.vecstore == mock_vecstore
@patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors')
def test_processor_initialization_with_custom_params(self, mock_doc_vectors):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_doc_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=taskgroup_mock,
store_uri='http://custom-milvus:19530'
)
mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530')
assert processor.vecstore == mock_vecstore
@pytest.mark.asyncio
async def test_query_document_embeddings_single_vector(self, processor):
"""Test querying document embeddings with a single vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock search results
mock_results = [
{"entity": {"doc": "First document chunk"}},
{"entity": {"doc": "Second document chunk"}},
{"entity": {"doc": "Third document chunk"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify search was called with correct parameters
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)
# Verify results are document chunks
assert len(result) == 3
assert result[0] == "First document chunk"
assert result[1] == "Second document chunk"
assert result[2] == "Third document chunk"
@pytest.mark.asyncio
async def test_query_document_embeddings_multiple_vectors(self, processor):
"""Test querying document embeddings with multiple vectors"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=3
)
# Mock search results - different results for each vector
mock_results_1 = [
{"entity": {"doc": "Document from first vector"}},
{"entity": {"doc": "Another doc from first vector"}},
]
mock_results_2 = [
{"entity": {"doc": "Document from second vector"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_document_embeddings(query)
# Verify search was called twice with correct parameters
expected_calls = [
(([0.1, 0.2, 0.3],), {"limit": 3}),
(([0.4, 0.5, 0.6],), {"limit": 3}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results from all vectors are combined
assert len(result) == 3
assert "Document from first vector" in result
assert "Another doc from first vector" in result
assert "Document from second vector" in result
@pytest.mark.asyncio
async def test_query_document_embeddings_with_limit(self, processor):
"""Test querying document embeddings respects limit parameter"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=2
)
# Mock search results - more results than limit
mock_results = [
{"entity": {"doc": "Document 1"}},
{"entity": {"doc": "Document 2"}},
{"entity": {"doc": "Document 3"}},
{"entity": {"doc": "Document 4"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify search was called with the specified limit
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=2)
# Verify all results are returned (Milvus handles limit internally)
assert len(result) == 4
@pytest.mark.asyncio
async def test_query_document_embeddings_empty_vectors(self, processor):
"""Test querying document embeddings with empty vectors list"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[],
limit=5
)
result = await processor.query_document_embeddings(query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
# Verify empty results
assert len(result) == 0
@pytest.mark.asyncio
async def test_query_document_embeddings_empty_search_results(self, processor):
"""Test querying document embeddings with empty search results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_document_embeddings(query)
# Verify search was called
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)
# Verify empty results
assert len(result) == 0
@pytest.mark.asyncio
async def test_query_document_embeddings_unicode_documents(self, processor):
"""Test querying document embeddings with Unicode document content"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock search results with Unicode content
mock_results = [
{"entity": {"doc": "Document with Unicode: éñ中文🚀"}},
{"entity": {"doc": "Regular ASCII document"}},
{"entity": {"doc": "Document with émojis: 😀🎉"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify Unicode content is preserved
assert len(result) == 3
assert "Document with Unicode: éñ中文🚀" in result
assert "Regular ASCII document" in result
assert "Document with émojis: 😀🎉" in result
@pytest.mark.asyncio
async def test_query_document_embeddings_large_documents(self, processor):
"""Test querying document embeddings with large document content"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock search results with large content
large_doc = "A" * 10000 # 10KB of content
mock_results = [
{"entity": {"doc": large_doc}},
{"entity": {"doc": "Small document"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify large content is preserved
assert len(result) == 2
assert large_doc in result
assert "Small document" in result
@pytest.mark.asyncio
async def test_query_document_embeddings_special_characters(self, processor):
"""Test querying document embeddings with special characters in documents"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock search results with special characters
mock_results = [
{"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}},
{"entity": {"doc": "Document with\nnewlines\tand\ttabs"}},
{"entity": {"doc": "Document with special chars: @#$%^&*()"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify special characters are preserved
assert len(result) == 3
assert "Document with \"quotes\" and 'apostrophes'" in result
assert "Document with\nnewlines\tand\ttabs" in result
assert "Document with special chars: @#$%^&*()" in result
@pytest.mark.asyncio
async def test_query_document_embeddings_zero_limit(self, processor):
"""Test querying document embeddings with zero limit"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=0
)
result = await processor.query_document_embeddings(query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
# Verify empty results due to zero limit
assert len(result) == 0
@pytest.mark.asyncio
async def test_query_document_embeddings_negative_limit(self, processor):
"""Test querying document embeddings with negative limit"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=-1
)
result = await processor.query_document_embeddings(query)
# Verify no search was called (optimization for negative limit)
processor.vecstore.search.assert_not_called()
# Verify empty results due to negative limit
assert len(result) == 0
@pytest.mark.asyncio
async def test_query_document_embeddings_exception_handling(self, processor):
"""Test exception handling during query processing"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock search to raise exception
processor.vecstore.search.side_effect = Exception("Milvus connection failed")
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_document_embeddings(query)
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
"""Test querying document embeddings with different vector dimensions"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
limit=5
)
# Mock search results for each vector
mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}]
mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}]
mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
result = await processor.query_document_embeddings(query)
# Verify all vectors were searched
assert processor.vecstore.search.call_count == 3
# Verify results from all dimensions
assert len(result) == 3
assert "Document from 2D vector" in result
assert "Document from 4D vector" in result
assert "Document from 3D vector" in result
@pytest.mark.asyncio
async def test_query_document_embeddings_duplicate_documents(self, processor):
"""Test querying document embeddings with duplicate documents in results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=5
)
# Mock search results with duplicates across vectors
mock_results_1 = [
{"entity": {"doc": "Document A"}},
{"entity": {"doc": "Document B"}},
]
mock_results_2 = [
{"entity": {"doc": "Document B"}}, # Duplicate
{"entity": {"doc": "Document C"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_document_embeddings(query)
# Note: Unlike graph embeddings, doc embeddings don't deduplicate
# This preserves ranking and allows multiple occurrences
assert len(result) == 4
assert result.count("Document B") == 2 # Should appear twice
assert "Document A" in result
assert "Document C" in result
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'store_uri')
assert args.store_uri == 'http://localhost:19530'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--store-uri', 'http://custom-milvus:19530'
])
assert args.store_uri == 'http://custom-milvus:19530'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
assert args.store_uri == 'http://short-milvus:19530'
@patch('trustgraph.query.doc_embeddings.milvus.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.doc_embeddings.milvus.service import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n"
)

View file

@ -0,0 +1,558 @@
"""
Tests for Pinecone document embeddings query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.doc_embeddings.pinecone.service import Processor
class TestPineconeDocEmbeddingsQueryProcessor:
"""Test cases for Pinecone document embeddings query processor"""
@pytest.fixture
def mock_query_message(self):
"""Create a mock query message for testing"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone') as mock_pinecone_class:
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
processor = Processor(
taskgroup=MagicMock(),
id='test-pinecone-de-query',
api_key='test-api-key'
)
return processor
@patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone')
@patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'env-api-key')
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
"""Test processor initialization with default parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
assert processor.pinecone == mock_pinecone
assert processor.api_key == 'env-api-key'
@patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone')
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
"""Test processor initialization with custom parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='custom-api-key'
)
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
assert processor.api_key == 'custom-api-key'
@patch('trustgraph.query.doc_embeddings.pinecone.service.PineconeGRPC')
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
"""Test processor initialization with custom URL (GRPC mode)"""
mock_pinecone = MagicMock()
mock_pinecone_grpc_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='test-api-key',
url='https://custom-host.pinecone.io'
)
mock_pinecone_grpc_class.assert_called_once_with(
api_key='test-api-key',
host='https://custom-host.pinecone.io'
)
assert processor.pinecone == mock_pinecone
assert processor.url == 'https://custom-host.pinecone.io'
@patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'not-specified')
def test_processor_initialization_missing_api_key(self):
"""Test processor initialization fails with missing API key"""
taskgroup_mock = MagicMock()
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
Processor(taskgroup=taskgroup_mock)
@pytest.mark.asyncio
async def test_query_document_embeddings_single_vector(self, processor):
"""Test querying document embeddings with a single vector"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
# Mock index and query results
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'doc': 'First document chunk'}),
MagicMock(metadata={'doc': 'Second document chunk'}),
MagicMock(metadata={'doc': 'Third document chunk'})
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
# Verify index was accessed correctly
expected_index_name = "d-test_user-test_collection-3"
processor.pinecone.Index.assert_called_once_with(expected_index_name)
# Verify query parameters
mock_index.query.assert_called_once_with(
vector=[0.1, 0.2, 0.3],
top_k=3,
include_values=False,
include_metadata=True
)
# Verify results
assert len(chunks) == 3
assert chunks[0] == 'First document chunk'
assert chunks[1] == 'Second document chunk'
assert chunks[2] == 'Third document chunk'
@pytest.mark.asyncio
async def test_query_document_embeddings_multiple_vectors(self, processor, mock_query_message):
"""Test querying document embeddings with multiple vectors"""
# Mock index and query results
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# First query results
mock_results1 = MagicMock()
mock_results1.matches = [
MagicMock(metadata={'doc': 'Document chunk 1'}),
MagicMock(metadata={'doc': 'Document chunk 2'})
]
# Second query results
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'doc': 'Document chunk 3'}),
MagicMock(metadata={'doc': 'Document chunk 4'})
]
mock_index.query.side_effect = [mock_results1, mock_results2]
chunks = await processor.query_document_embeddings(mock_query_message)
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify results from both queries
assert len(chunks) == 4
assert 'Document chunk 1' in chunks
assert 'Document chunk 2' in chunks
assert 'Document chunk 3' in chunks
assert 'Document chunk 4' in chunks
@pytest.mark.asyncio
async def test_query_document_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
# Mock index with many results
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'doc': f'Document chunk {i}'}) for i in range(10)
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
# Verify limit is passed to query
mock_index.query.assert_called_once()
call_args = mock_index.query.call_args
assert call_args[1]['top_k'] == 2
# Results should contain all returned chunks (limit is applied by Pinecone)
assert len(chunks) == 10
@pytest.mark.asyncio
async def test_query_document_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 0
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
assert chunks == []
@pytest.mark.asyncio
async def test_query_document_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = -1
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
assert chunks == []
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
"""Test querying with vectors of different dimensions"""
message = MagicMock()
message.vectors = [
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6] # 4D vector
]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
processor.pinecone.Index.side_effect = mock_index_side_effect
# Mock results for different dimensions
mock_results_2d = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D index'})]
mock_index_2d.query.return_value = mock_results_2d
mock_results_4d = MagicMock()
mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D index'})]
mock_index_4d.query.return_value = mock_results_4d
chunks = await processor.query_document_embeddings(message)
# Verify different indexes were used
assert processor.pinecone.Index.call_count == 2
mock_index_2d.query.assert_called_once()
mock_index_4d.query.assert_called_once()
# Verify results from both dimensions
assert 'Document from 2D index' in chunks
assert 'Document from 4D index' in chunks
@pytest.mark.asyncio
async def test_query_document_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list"""
message = MagicMock()
message.vectors = []
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
mock_index.query.assert_not_called()
assert chunks == []
@pytest.mark.asyncio
async def test_query_document_embeddings_no_results(self, processor):
"""Test querying when index returns no results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_results = MagicMock()
mock_results.matches = []
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
# Verify empty results
assert chunks == []
@pytest.mark.asyncio
async def test_query_document_embeddings_unicode_content(self, processor):
"""Test querying document embeddings with Unicode content results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'doc': 'Document with Unicode: éñ中文🚀'}),
MagicMock(metadata={'doc': 'Regular ASCII document'})
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
# Verify Unicode content is properly handled
assert len(chunks) == 2
assert 'Document with Unicode: éñ中文🚀' in chunks
assert 'Regular ASCII document' in chunks
@pytest.mark.asyncio
async def test_query_document_embeddings_large_content(self, processor):
"""Test querying document embeddings with large content results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 1
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Create a large document content
large_content = "A" * 10000 # 10KB of content
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'doc': large_content})
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
# Verify large content is properly handled
assert len(chunks) == 1
assert chunks[0] == large_content
@pytest.mark.asyncio
async def test_query_document_embeddings_mixed_content_types(self, processor):
"""Test querying document embeddings with mixed content types"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'doc': 'Short text'}),
MagicMock(metadata={'doc': 'A' * 1000}), # Long text
MagicMock(metadata={'doc': 'Text with numbers: 123 and symbols: @#$'}),
MagicMock(metadata={'doc': ' Whitespace text '}),
MagicMock(metadata={'doc': ''}) # Empty string
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
# Verify all content types are properly handled
assert len(chunks) == 5
assert 'Short text' in chunks
assert 'A' * 1000 in chunks
assert 'Text with numbers: 123 and symbols: @#$' in chunks
assert ' Whitespace text ' in chunks
assert '' in chunks
@pytest.mark.asyncio
async def test_query_document_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_document_embeddings(message)
@pytest.mark.asyncio
async def test_query_document_embeddings_index_access_failure(self, processor):
"""Test handling of index access failure"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
processor.pinecone.Index.side_effect = Exception("Index access failed")
with pytest.raises(Exception, match="Index access failed"):
await processor.query_document_embeddings(message)
@pytest.mark.asyncio
async def test_query_document_embeddings_vector_accumulation(self, processor):
"""Test that results from multiple vectors are properly accumulated"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Each query returns different results
mock_results1 = MagicMock()
mock_results1.matches = [
MagicMock(metadata={'doc': 'Doc from vector 1.1'}),
MagicMock(metadata={'doc': 'Doc from vector 1.2'})
]
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'doc': 'Doc from vector 2.1'})
]
mock_results3 = MagicMock()
mock_results3.matches = [
MagicMock(metadata={'doc': 'Doc from vector 3.1'}),
MagicMock(metadata={'doc': 'Doc from vector 3.2'}),
MagicMock(metadata={'doc': 'Doc from vector 3.3'})
]
mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]
chunks = await processor.query_document_embeddings(message)
# Verify all queries were made
assert mock_index.query.call_count == 3
# Verify all results are accumulated
assert len(chunks) == 6
assert 'Doc from vector 1.1' in chunks
assert 'Doc from vector 1.2' in chunks
assert 'Doc from vector 2.1' in chunks
assert 'Doc from vector 3.1' in chunks
assert 'Doc from vector 3.2' in chunks
assert 'Doc from vector 3.3' in chunks
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
args = parser.parse_args([])
assert hasattr(args, 'api_key')
assert args.api_key == 'not-specified' # Default value when no env var
assert hasattr(args, 'url')
assert args.url is None
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--api-key', 'custom-api-key',
'--url', 'https://custom-host.pinecone.io'
])
assert args.api_key == 'custom-api-key'
assert args.url == 'https://custom-host.pinecone.io'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.doc_embeddings.pinecone.service.DocumentEmbeddingsQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args([
'-a', 'short-api-key',
'-u', 'https://short-host.pinecone.io'
])
assert args.api_key == 'short-api-key'
assert args.url == 'https://short-host.pinecone.io'
@patch('trustgraph.query.doc_embeddings.pinecone.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.doc_embeddings.pinecone.service import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nDocument embeddings query service. Input is vector, output is an array\nof chunks. Pinecone implementation.\n"
)

View file

@ -0,0 +1,542 @@
"""
Unit tests for trustgraph.query.doc_embeddings.qdrant.service
Testing document embeddings query functionality
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.doc_embeddings.qdrant.service import Processor
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
"""Test Qdrant document embeddings query functionality"""
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
"""Test basic Qdrant processor initialization"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-query-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify base class initialization was called
mock_base_init.assert_called_once()
# Verify QdrantClient was created with correct parameters
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
# Verify processor attributes
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
"""Test processor initialization with default values"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-doc-query-processor'
# No store_uri or api_key provided - should use defaults
}
# Act
processor = Processor(**config)
# Assert
# Verify QdrantClient was created with default URI and None API key
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with single vector"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'first document chunk'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'second document chunk'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called with correct parameters
expected_collection = 'd_test_user_test_collection_3'
mock_qdrant_instance.query_points.assert_called_once_with(
collection_name=expected_collection,
query=[0.1, 0.2, 0.3],
limit=5, # Direct limit, no multiplication
with_payload=True
)
# Verify result contains expected documents
assert len(result) == 2
# Results should be strings (document chunks)
assert isinstance(result[0], str)
assert isinstance(result[1], str)
# Verify content
assert result[0] == 'first document chunk'
assert result[1] == 'second document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with multiple vectors"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses for different vectors
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'document from vector 1'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'document from vector 2'}
mock_point3 = MagicMock()
mock_point3.payload = {'doc': 'another document from vector 2'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2, mock_point3]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with multiple vectors
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify both collections were queried
expected_collection = 'd_multi_user_multi_collection_2'
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify results from both vectors are combined
assert len(result) == 3
assert 'document from vector 1' in result
assert 'document from vector 2' in result
assert 'another document from vector 2' in result
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings respects limit parameter"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response with many results
mock_points = []
for i in range(10):
mock_point = MagicMock()
mock_point.payload = {'doc': f'document chunk {i}'}
mock_points.append(mock_point)
mock_response = MagicMock()
mock_response.points = mock_points
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called with exact limit (no multiplication)
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 3 # Direct limit
# Verify result contains all returned documents (limit applied by Qdrant)
assert len(result) == 10 # All results returned by mock
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with empty results"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock empty query response
mock_response = MagicMock()
mock_response.points = []
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 5
mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
assert result == []
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with different vector dimensions"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'document from 2D vector'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'document from 3D vector'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with different dimension vectors
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
assert len(result) == 2
assert 'document from 2D vector' in result
assert 'document from 3D vector' in result
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_utf8_encoding(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with UTF-8 content"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response with UTF-8 content
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 5
mock_message.user = 'utf8_user'
mock_message.collection = 'utf8_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
assert len(result) == 2
# Verify UTF-8 content works correctly
assert 'Document with UTF-8: café, naïve, résumé' in result
assert 'Chinese text: 你好世界' in result
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings handles Qdrant errors"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock Qdrant error
mock_qdrant_instance.query_points.side_effect = Exception("Qdrant connection failed")
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 5
mock_message.user = 'error_user'
mock_message.collection = 'error_collection'
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_document_embeddings(mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with zero limit"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response
mock_point = MagicMock()
mock_point.payload = {'doc': 'document chunk'}
mock_response = MagicMock()
mock_response.points = [mock_point]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with zero limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 0
mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Should still query (with limit 0)
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 0
# Result should contain all returned documents
assert len(result) == 1
assert result[0] == 'document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_large_limit(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with large limit"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response with fewer results than limit
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'document 1'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'document 2'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with large limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 1000 # Large limit
mock_message.user = 'large_user'
mock_message.collection = 'large_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Should query with full limit
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 1000
# Result should contain all available documents
assert len(result) == 2
assert 'document 1' in result
assert 'document 2' in result
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_missing_payload(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with missing payload data"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response with missing 'doc' key
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'valid document'}
mock_point2 = MagicMock()
mock_point2.payload = {} # Missing 'doc' key
mock_point3 = MagicMock()
mock_point3.payload = {'other_key': 'invalid'} # Wrong key
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2, mock_point3]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 5
mock_message.user = 'payload_user'
mock_message.collection = 'payload_collection'
# Act & Assert
# This should raise a KeyError when trying to access payload['doc']
with pytest.raises(KeyError):
await processor.query_document_embeddings(mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
"""Test that add_args() calls parent add_args method"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_client.return_value = MagicMock()
mock_parser = MagicMock()
# Act
with patch('trustgraph.base.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
Processor.add_args(mock_parser)
# Assert
mock_parent_add_args.assert_called_once_with(mock_parser)
# Verify processor-specific arguments were added
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,484 @@
"""
Tests for Milvus graph embeddings query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.milvus.service import Processor
from trustgraph.schema import Value, GraphEmbeddingsRequest
class TestMilvusGraphEmbeddingsQueryProcessor:
"""Test cases for Milvus graph embeddings query processor"""
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors') as mock_entity_vectors:
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=MagicMock(),
id='test-milvus-ge-query',
store_uri='http://localhost:19530'
)
return processor
@pytest.fixture
def mock_query_request(self):
"""Create a mock query request for testing"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=10
)
return query
@patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors')
def test_processor_initialization_with_defaults(self, mock_entity_vectors):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(taskgroup=taskgroup_mock)
mock_entity_vectors.assert_called_once_with('http://localhost:19530')
assert processor.vecstore == mock_vecstore
@patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors')
def test_processor_initialization_with_custom_params(self, mock_entity_vectors):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_vecstore = MagicMock()
mock_entity_vectors.return_value = mock_vecstore
processor = Processor(
taskgroup=taskgroup_mock,
store_uri='http://custom-milvus:19530'
)
mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530')
assert processor.vecstore == mock_vecstore
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert result.value == "just a literal string"
assert result.is_uri is False
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert result.value == ""
assert result.is_uri is False
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert result.value == "http"
assert result.is_uri is False
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
@pytest.mark.asyncio
async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock search results
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "literal entity"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify search was called with correct parameters
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)
# Verify results are converted to Value objects
assert len(result) == 3
assert isinstance(result[0], Value)
assert result[0].value == "http://example.com/entity1"
assert result[0].is_uri is True
assert isinstance(result[1], Value)
assert result[1].value == "http://example.com/entity2"
assert result[1].is_uri is True
assert isinstance(result[2], Value)
assert result[2].value == "literal entity"
assert result[2].is_uri is False
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor):
"""Test querying graph embeddings with multiple vectors"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=3
)
# Mock search results - different results for each vector
mock_results_1 = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_graph_embeddings(query)
# Verify search was called twice with correct parameters
expected_calls = [
(([0.1, 0.2, 0.3],), {"limit": 6}),
(([0.4, 0.5, 0.6],), {"limit": 6}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results are deduplicated and limited
assert len(result) == 3
entity_values = [r.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@pytest.mark.asyncio
async def test_query_graph_embeddings_with_limit(self, processor):
"""Test querying graph embeddings respects limit parameter"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=2
)
# Mock search results - more results than limit
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
{"entity": {"entity": "http://example.com/entity4"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify search was called with 2*limit for better deduplication
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)
# Verify results are limited to the requested limit
assert len(result) == 2
@pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication(self, processor):
"""Test that duplicate entities are properly deduplicated"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=5
)
# Mock search results with duplicates
mock_results_1 = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}}, # New
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_graph_embeddings(query)
# Verify duplicates are removed
assert len(result) == 3
entity_values = [r.value for r in result]
assert len(set(entity_values)) == 3 # All unique
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
"""Test that querying stops early when limit is reached"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=2
)
# Mock search results - first vector returns enough results
mock_results_1 = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.return_value = mock_results_1
result = await processor.query_graph_embeddings(query)
# Verify only first vector was searched (limit reached)
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)
# Verify results are limited
assert len(result) == 2
@pytest.mark.asyncio
async def test_query_graph_embeddings_empty_vectors(self, processor):
"""Test querying graph embeddings with empty vectors list"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[],
limit=5
)
result = await processor.query_graph_embeddings(query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
# Verify empty results
assert len(result) == 0
@pytest.mark.asyncio
async def test_query_graph_embeddings_empty_search_results(self, processor):
"""Test querying graph embeddings with empty search results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_graph_embeddings(query)
# Verify search was called
processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)
# Verify empty results
assert len(result) == 0
@pytest.mark.asyncio
async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
"""Test querying graph embeddings with mixed URI and literal results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock search results with mixed types
mock_results = [
{"entity": {"entity": "http://example.com/uri_entity"}},
{"entity": {"entity": "literal entity text"}},
{"entity": {"entity": "https://example.com/another_uri"}},
{"entity": {"entity": "another literal"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify all results are properly typed
assert len(result) == 4
# Check URI entities
uri_results = [r for r in result if r.is_uri]
assert len(uri_results) == 2
uri_values = [r.value for r in uri_results]
assert "http://example.com/uri_entity" in uri_values
assert "https://example.com/another_uri" in uri_values
# Check literal entities
literal_results = [r for r in result if not r.is_uri]
assert len(literal_results) == 2
literal_values = [r.value for r in literal_results]
assert "literal entity text" in literal_values
assert "another literal" in literal_values
@pytest.mark.asyncio
async def test_query_graph_embeddings_exception_handling(self, processor):
"""Test exception handling during query processing"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=5
)
# Mock search to raise exception
processor.vecstore.search.side_effect = Exception("Milvus connection failed")
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_graph_embeddings(query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'store_uri')
assert args.store_uri == 'http://localhost:19530'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--store-uri', 'http://custom-milvus:19530'
])
assert args.store_uri == 'http://custom-milvus:19530'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-t', 'http://short-milvus:19530'])
assert args.store_uri == 'http://short-milvus:19530'
@patch('trustgraph.query.graph_embeddings.milvus.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.graph_embeddings.milvus.service import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nGraph embeddings query service. Input is vector, output is list of\nentities\n"
)
@pytest.mark.asyncio
async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying graph embeddings with zero limit"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
limit=0
)
result = await processor.query_graph_embeddings(query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
# Verify empty results due to zero limit
assert len(result) == 0
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
"""Test querying graph embeddings with different vector dimensions"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
limit=5
)
# Mock search results for each vector
mock_results_1 = [{"entity": {"entity": "entity_2d"}}]
mock_results_2 = [{"entity": {"entity": "entity_4d"}}]
mock_results_3 = [{"entity": {"entity": "entity_3d"}}]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
result = await processor.query_graph_embeddings(query)
# Verify all vectors were searched
assert processor.vecstore.search.call_count == 3
# Verify results from all dimensions
assert len(result) == 3
entity_values = [r.value for r in result]
assert "entity_2d" in entity_values
assert "entity_4d" in entity_values
assert "entity_3d" in entity_values

View file

@ -0,0 +1,507 @@
"""
Tests for Pinecone graph embeddings query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.pinecone.service import Processor
from trustgraph.schema import Value
class TestPineconeGraphEmbeddingsQueryProcessor:
"""Test cases for Pinecone graph embeddings query processor"""
@pytest.fixture
def mock_query_message(self):
"""Create a mock query message for testing"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
return message
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone') as mock_pinecone_class:
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
processor = Processor(
taskgroup=MagicMock(),
id='test-pinecone-ge-query',
api_key='test-api-key'
)
return processor
@patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone')
@patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'env-api-key')
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
"""Test processor initialization with default parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
assert processor.pinecone == mock_pinecone
assert processor.api_key == 'env-api-key'
@patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone')
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
"""Test processor initialization with custom parameters"""
mock_pinecone = MagicMock()
mock_pinecone_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='custom-api-key'
)
mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
assert processor.api_key == 'custom-api-key'
@patch('trustgraph.query.graph_embeddings.pinecone.service.PineconeGRPC')
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
"""Test processor initialization with custom URL (GRPC mode)"""
mock_pinecone = MagicMock()
mock_pinecone_grpc_class.return_value = mock_pinecone
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
api_key='test-api-key',
url='https://custom-host.pinecone.io'
)
mock_pinecone_grpc_class.assert_called_once_with(
api_key='test-api-key',
host='https://custom-host.pinecone.io'
)
assert processor.pinecone == mock_pinecone
assert processor.url == 'https://custom-host.pinecone.io'
@patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'not-specified')
def test_processor_initialization_missing_api_key(self):
"""Test processor initialization fails with missing API key"""
taskgroup_mock = MagicMock()
with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
Processor(taskgroup=taskgroup_mock)
def test_create_value_uri(self, processor):
"""Test create_value method for URI entities"""
uri_entity = "http://example.org/entity"
value = processor.create_value(uri_entity)
assert isinstance(value, Value)
assert value.value == uri_entity
assert value.is_uri == True
def test_create_value_https_uri(self, processor):
"""Test create_value method for HTTPS URI entities"""
uri_entity = "https://example.org/entity"
value = processor.create_value(uri_entity)
assert isinstance(value, Value)
assert value.value == uri_entity
assert value.is_uri == True
def test_create_value_literal(self, processor):
"""Test create_value method for literal entities"""
literal_entity = "literal_entity"
value = processor.create_value(literal_entity)
assert isinstance(value, Value)
assert value.value == literal_entity
assert value.is_uri == False
@pytest.mark.asyncio
async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
# Mock index and query results
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'entity': 'http://example.org/entity1'}),
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'http://example.org/entity3'})
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
# Verify index was accessed correctly
expected_index_name = "t-test_user-test_collection-3"
processor.pinecone.Index.assert_called_once_with(expected_index_name)
# Verify query parameters
mock_index.query.assert_called_once_with(
vector=[0.1, 0.2, 0.3],
top_k=6, # 2 * limit
include_values=False,
include_metadata=True
)
# Verify results
assert len(entities) == 3
assert entities[0].value == 'http://example.org/entity1'
assert entities[0].is_uri == True
assert entities[1].value == 'entity2'
assert entities[1].is_uri == False
assert entities[2].value == 'http://example.org/entity3'
assert entities[2].is_uri == True
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
"""Test querying graph embeddings with multiple vectors"""
# Mock index and query results
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# First query results
mock_results1 = MagicMock()
mock_results1.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'})
]
# Second query results
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'})
]
mock_index.query.side_effect = [mock_results1, mock_results2]
entities = await processor.query_graph_embeddings(mock_query_message)
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify deduplication occurred
entity_values = [e.value for e in entities]
assert len(entity_values) == 3
assert 'entity1' in entity_values
assert 'entity2' in entity_values
assert 'entity3' in entity_values
@pytest.mark.asyncio
async def test_query_graph_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
# Mock index with many results
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'entity': f'entity{i}'}) for i in range(10)
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
# Verify limit is respected
assert len(entities) == 2
@pytest.mark.asyncio
async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 0
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
assert entities == []
@pytest.mark.asyncio
async def test_query_graph_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = -1
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
assert entities == []
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
"""Test querying with vectors of different dimensions"""
message = MagicMock()
message.vectors = [
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6] # 4D vector
]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
processor.pinecone.Index.side_effect = mock_index_side_effect
# Mock results for different dimensions
mock_results_2d = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
mock_index_2d.query.return_value = mock_results_2d
mock_results_4d = MagicMock()
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
mock_index_4d.query.return_value = mock_results_4d
entities = await processor.query_graph_embeddings(message)
# Verify different indexes were used
assert processor.pinecone.Index.call_count == 2
mock_index_2d.query.assert_called_once()
mock_index_4d.query.assert_called_once()
# Verify results from both dimensions
entity_values = [e.value for e in entities]
assert 'entity_2d' in entity_values
assert 'entity_4d' in entity_values
@pytest.mark.asyncio
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list"""
message = MagicMock()
message.vectors = []
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
mock_index.query.assert_not_called()
assert entities == []
@pytest.mark.asyncio
async def test_query_graph_embeddings_no_results(self, processor):
"""Test querying when index returns no results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_results = MagicMock()
mock_results.matches = []
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
# Verify empty results
assert entities == []
@pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
"""Test that deduplication works correctly across multiple vector queries"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Both queries return overlapping results
mock_results1 = MagicMock()
mock_results1.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity3'}),
MagicMock(metadata={'entity': 'entity4'})
]
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
MagicMock(metadata={'entity': 'entity5'})
]
mock_index.query.side_effect = [mock_results1, mock_results2]
entities = await processor.query_graph_embeddings(message)
# Should get exactly 3 unique entities (respecting limit)
assert len(entities) == 3
entity_values = [e.value for e in entities]
assert len(set(entity_values)) == 3 # All unique
@pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
"""Test that querying stops early when limit is reached"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# First query returns enough results to meet limit
mock_results1 = MagicMock()
mock_results1.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity3'})
]
mock_index.query.return_value = mock_results1
entities = await processor.query_graph_embeddings(message)
# Should only make one query since limit was reached
mock_index.query.assert_called_once()
assert len(entities) == 2
@pytest.mark.asyncio
async def test_query_graph_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_graph_embeddings(message)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
args = parser.parse_args([])
assert hasattr(args, 'api_key')
assert args.api_key == 'not-specified' # Default value when no env var
assert hasattr(args, 'url')
assert args.url is None
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--api-key', 'custom-api-key',
'--url', 'https://custom-host.pinecone.io'
])
assert args.api_key == 'custom-api-key'
assert args.url == 'https://custom-host.pinecone.io'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.graph_embeddings.pinecone.service.GraphEmbeddingsQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args([
'-a', 'short-api-key',
'-u', 'https://short-host.pinecone.io'
])
assert args.api_key == 'short-api-key'
assert args.url == 'https://short-host.pinecone.io'
@patch('trustgraph.query.graph_embeddings.pinecone.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.graph_embeddings.pinecone.service import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nGraph embeddings query service. Input is vector, output is list of\nentities. Pinecone implementation.\n"
)

View file

@ -0,0 +1,537 @@
"""
Unit tests for trustgraph.query.graph_embeddings.qdrant.service
Testing graph embeddings query functionality
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.graph_embeddings.qdrant.service import Processor
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
"""Test Qdrant graph embeddings query functionality"""
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
"""Test basic Qdrant processor initialization"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-graph-query-processor'
}
# Act
processor = Processor(**config)
# Assert
# Verify base class initialization was called
mock_base_init.assert_called_once()
# Verify QdrantClient was created with correct parameters
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
# Verify processor attributes
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
"""Test processor initialization with default values"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-graph-query-processor'
# No store_uri or api_key provided - should use defaults
}
# Act
processor = Processor(**config)
# Assert
# Verify QdrantClient was created with default URI and None API key
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_create_value_http_uri(self, mock_base_init, mock_qdrant_client):
"""Test create_value with HTTP URI"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_client.return_value = MagicMock()
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
value = processor.create_value('http://example.com/entity')
# Assert
assert hasattr(value, 'value')
assert value.value == 'http://example.com/entity'
assert hasattr(value, 'is_uri')
assert value.is_uri == True
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_create_value_https_uri(self, mock_base_init, mock_qdrant_client):
"""Test create_value with HTTPS URI"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_client.return_value = MagicMock()
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
value = processor.create_value('https://secure.example.com/entity')
# Assert
assert hasattr(value, 'value')
assert value.value == 'https://secure.example.com/entity'
assert hasattr(value, 'is_uri')
assert value.is_uri == True
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_create_value_regular_string(self, mock_base_init, mock_qdrant_client):
"""Test create_value with regular string (non-URI)"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_client.return_value = MagicMock()
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
value = processor.create_value('regular entity name')
# Assert
assert hasattr(value, 'value')
assert value.value == 'regular entity name'
assert hasattr(value, 'is_uri')
assert value.is_uri == False
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_query_graph_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
"""Test querying graph embeddings with single vector"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response
mock_point1 = MagicMock()
mock_point1.payload = {'entity': 'entity1'}
mock_point2 = MagicMock()
mock_point2.payload = {'entity': 'entity2'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called with correct parameters
expected_collection = 't_test_user_test_collection_3'
mock_qdrant_instance.query_points.assert_called_once_with(
collection_name=expected_collection,
query=[0.1, 0.2, 0.3],
limit=10, # limit * 2 for deduplication
with_payload=True
)
# Verify result contains expected entities
assert len(result) == 2
assert all(hasattr(entity, 'value') for entity in result)
entity_values = [entity.value for entity in result]
assert 'entity1' in entity_values
assert 'entity2' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_query_graph_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
"""Test querying graph embeddings with multiple vectors"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses for different vectors
mock_point1 = MagicMock()
mock_point1.payload = {'entity': 'entity1'}
mock_point2 = MagicMock()
mock_point2.payload = {'entity': 'entity2'}
mock_point3 = MagicMock()
mock_point3.payload = {'entity': 'entity3'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1, mock_point2]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2, mock_point3]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with multiple vectors
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify both collections were queried
expected_collection = 't_multi_user_multi_collection_2'
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify deduplication - entity2 appears in both results but should only appear once
entity_values = [entity.value for entity in result]
assert len(set(entity_values)) == len(entity_values) # All unique
assert 'entity1' in entity_values
assert 'entity2' in entity_values
assert 'entity3' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_query_graph_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
"""Test querying graph embeddings respects limit parameter"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response with more results than limit
mock_points = []
for i in range(10):
mock_point = MagicMock()
mock_point.payload = {'entity': f'entity{i}'}
mock_points.append(mock_point)
mock_response = MagicMock()
mock_response.points = mock_points
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called with limit * 2
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 6 # 3 * 2
# Verify result is limited to requested limit
assert len(result) == 3
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_query_graph_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
"""Test querying graph embeddings with empty results"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock empty query response
mock_response = MagicMock()
mock_response.points = []
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 5
mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
assert result == []
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_query_graph_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
"""Test querying graph embeddings with different vector dimensions"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses
mock_point1 = MagicMock()
mock_point1.payload = {'entity': 'entity2d'}
mock_point2 = MagicMock()
mock_point2.payload = {'entity': 'entity3d'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with different dimension vectors
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2'
assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3'
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
entity_values = [entity.value for entity in result]
assert 'entity2d' in entity_values
assert 'entity3d' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_query_graph_embeddings_uri_detection(self, mock_base_init, mock_qdrant_client):
"""Test querying graph embeddings with URI detection"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response with URIs and regular strings
mock_point1 = MagicMock()
mock_point1.payload = {'entity': 'http://example.com/entity1'}
mock_point2 = MagicMock()
mock_point2.payload = {'entity': 'https://secure.example.com/entity2'}
mock_point3 = MagicMock()
mock_point3.payload = {'entity': 'regular entity'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2, mock_point3]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 5
mock_message.user = 'uri_user'
mock_message.collection = 'uri_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
assert len(result) == 3
# Check URI entities
uri_entities = [entity for entity in result if hasattr(entity, 'is_uri') and entity.is_uri]
assert len(uri_entities) == 2
uri_values = [entity.value for entity in uri_entities]
assert 'http://example.com/entity1' in uri_values
assert 'https://secure.example.com/entity2' in uri_values
# Check regular entities
regular_entities = [entity for entity in result if hasattr(entity, 'is_uri') and not entity.is_uri]
assert len(regular_entities) == 1
assert regular_entities[0].value == 'regular entity'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_query_graph_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
"""Test querying graph embeddings handles Qdrant errors"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock Qdrant error
mock_qdrant_instance.query_points.side_effect = Exception("Qdrant connection failed")
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 5
mock_message.user = 'error_user'
mock_message.collection = 'error_collection'
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_graph_embeddings(mock_message)
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_query_graph_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
"""Test querying graph embeddings with zero limit"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response - even with zero limit, Qdrant might return results
mock_point = MagicMock()
mock_point.payload = {'entity': 'entity1'}
mock_response = MagicMock()
mock_response.points = [mock_point]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with zero limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.limit = 0
mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Should still query (with limit 0)
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 0 # 0 * 2 = 0
# With zero limit, the logic still adds one entity before checking the limit
# So it returns one result (current behavior, not ideal but actual)
assert len(result) == 1
assert result[0].value == 'entity1'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
"""Test that add_args() calls parent add_args method"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_client.return_value = MagicMock()
mock_parser = MagicMock()
# Act
with patch('trustgraph.base.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
Processor.add_args(mock_parser)
# Assert
mock_parent_add_args.assert_called_once_with(mock_parser)
# Verify processor-specific arguments were added
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,539 @@
"""
Tests for Cassandra triples query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.cassandra.service import Processor
from trustgraph.schema import Value
class TestCassandraQueryProcessor:
"""Test cases for Cassandra query processor"""
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
return Processor(
taskgroup=MagicMock(),
id='test-cassandra-query',
graph_host='localhost'
)
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert result.value == "just a literal string"
assert result.is_uri is False
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert result.value == ""
assert result.is_uri is False
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert result.value == "http"
assert result.is_uri is False
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_spo_query(self, mock_trustgraph):
"""Test querying triples with subject, predicate, and object specified"""
from trustgraph.schema import TriplesQueryRequest, Value
# Setup mock TrustGraph
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.return_value = None # SPO query returns None if found
processor = Processor(
taskgroup=MagicMock(),
id='test-cassandra-query',
graph_host='localhost'
)
# Create query request with all SPO values
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify TrustGraph was created with correct parameters
mock_trustgraph.assert_called_once_with(
hosts=['localhost'],
keyspace='test_user',
table='test_collection'
)
# Verify get_spo was called with correct parameters
mock_tg_instance.get_spo.assert_called_once_with(
'test_subject', 'test_predicate', 'test_object', limit=100
)
# Verify result contains the queried triple
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'test_object'
def test_processor_initialization_with_defaults(self):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
assert processor.graph_host == ['localhost']
assert processor.username is None
assert processor.password is None
assert processor.table is None
def test_processor_initialization_with_custom_params(self):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='cassandra.example.com',
graph_username='queryuser',
graph_password='querypass'
)
assert processor.graph_host == ['cassandra.example.com']
assert processor.username == 'queryuser'
assert processor.password == 'querypass'
assert processor.table is None
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_sp_pattern(self, mock_trustgraph):
"""Test SP query pattern (subject and predicate, no object)"""
from trustgraph.schema import TriplesQueryRequest, Value
# Setup mock TrustGraph and response
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.o = 'result_object'
mock_tg_instance.get_sp.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=None,
limit=50
)
result = await processor.query_triples(query)
mock_tg_instance.get_sp.assert_called_once_with('test_subject', 'test_predicate', limit=50)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_s_pattern(self, mock_trustgraph):
"""Test S query pattern (subject only)"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.p = 'result_predicate'
mock_result.o = 'result_object'
mock_tg_instance.get_s.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=None,
o=None,
limit=25
)
result = await processor.query_triples(query)
mock_tg_instance.get_s.assert_called_once_with('test_subject', limit=25)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_p_pattern(self, mock_trustgraph):
"""Test P query pattern (predicate only)"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_result.o = 'result_object'
mock_tg_instance.get_p.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value='test_predicate', is_uri=False),
o=None,
limit=10
)
result = await processor.query_triples(query)
mock_tg_instance.get_p.assert_called_once_with('test_predicate', limit=10)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_o_pattern(self, mock_trustgraph):
"""Test O query pattern (object only)"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_result.p = 'result_predicate'
mock_tg_instance.get_o.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=Value(value='test_object', is_uri=False),
limit=75
)
result = await processor.query_triples(query)
mock_tg_instance.get_o.assert_called_once_with('test_object', limit=75)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
"""Test query pattern with no constraints (get all)"""
from trustgraph.schema import TriplesQueryRequest
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'all_subject'
mock_result.p = 'all_predicate'
mock_result.o = 'all_object'
mock_tg_instance.get_all.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=None,
limit=1000
)
result = await processor.query_triples(query)
mock_tg_instance.get_all.assert_called_once_with(limit=1000)
assert len(result) == 1
assert result[0].s.value == 'all_subject'
assert result[0].p.value == 'all_predicate'
assert result[0].o.value == 'all_object'
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once_with(parser)
# Verify our specific arguments were added
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'localhost'
assert hasattr(args, 'graph_username')
assert args.graph_username is None
assert hasattr(args, 'graph_password')
assert args.graph_password is None
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
parser = ArgumentParser()
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-host', 'query.cassandra.com',
'--graph-username', 'queryuser',
'--graph-password', 'querypass'
])
assert args.graph_host == 'query.cassandra.com'
assert args.graph_username == 'queryuser'
assert args.graph_password == 'querypass'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
parser = ArgumentParser()
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'short.query.com'])
assert args.graph_host == 'short.query.com'
@patch('trustgraph.query.triples.cassandra.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.triples.cassandra.service import run, default_ident
run()
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_with_authentication(self, mock_trustgraph):
"""Test querying with username and password authentication"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.return_value = None
processor = Processor(
taskgroup=MagicMock(),
graph_username='authuser',
graph_password='authpass'
)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
limit=100
)
await processor.query_triples(query)
# Verify TrustGraph was created with authentication
mock_trustgraph.assert_called_once_with(
hosts=['localhost'],
keyspace='test_user',
table='test_collection',
username='authuser',
password='authpass'
)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_table_reuse(self, mock_trustgraph):
"""Test that TrustGraph is reused for same table"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.return_value = None
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
limit=100
)
# First query should create TrustGraph
await processor.query_triples(query)
assert mock_trustgraph.call_count == 1
# Second query with same table should reuse TrustGraph
await processor.query_triples(query)
assert mock_trustgraph.call_count == 1 # Should not increase
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_table_switching(self, mock_trustgraph):
"""Test table switching creates new TrustGraph"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance1 = MagicMock()
mock_tg_instance2 = MagicMock()
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
processor = Processor(taskgroup=MagicMock())
# First query
query1 = TriplesQueryRequest(
user='user1',
collection='collection1',
s=Value(value='test_subject', is_uri=False),
p=None,
o=None,
limit=100
)
await processor.query_triples(query1)
assert processor.table == ('user1', 'collection1')
# Second query with different table
query2 = TriplesQueryRequest(
user='user2',
collection='collection2',
s=Value(value='test_subject', is_uri=False),
p=None,
o=None,
limit=100
)
await processor.query_triples(query2)
assert processor.table == ('user2', 'collection2')
# Verify TrustGraph was created twice
assert mock_trustgraph.call_count == 2
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_exception_handling(self, mock_trustgraph):
"""Test exception handling during query execution"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
limit=100
)
with pytest.raises(Exception, match="Query failed"):
await processor.query_triples(query)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_multiple_results(self, mock_trustgraph):
"""Test query returning multiple results"""
from trustgraph.schema import TriplesQueryRequest, Value
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
# Mock multiple results
mock_result1 = MagicMock()
mock_result1.o = 'object1'
mock_result2 = MagicMock()
mock_result2.o = 'object2'
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=None,
limit=100
)
result = await processor.query_triples(query)
assert len(result) == 2
assert result[0].o.value == 'object1'
assert result[1].o.value == 'object2'

View file

@ -0,0 +1,556 @@
"""
Tests for FalkorDB triples query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.falkordb.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
class TestFalkorDBQueryProcessor:
"""Test cases for FalkorDB query processor"""
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.query.triples.falkordb.service.FalkorDB'):
return Processor(
taskgroup=MagicMock(),
id='test-falkordb-query',
graph_url='falkor://localhost:6379'
)
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert result.value == "just a literal string"
assert result.is_uri is False
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert result.value == ""
assert result.is_uri is False
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert result.value == "http"
assert result.is_uri is False
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
def test_processor_initialization_with_defaults(self, mock_falkordb):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'falkordb'
mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379')
mock_client.select_graph.assert_called_once_with('falkordb')
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
def test_processor_initialization_with_custom_params(self, mock_falkordb):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
processor = Processor(
taskgroup=taskgroup_mock,
graph_url='falkor://custom:6379',
database='customdb'
)
assert processor.db == 'customdb'
mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379')
mock_client.select_graph.assert_called_once_with('customdb')
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_spo_query(self, mock_falkordb):
"""Test SPO query (all values specified)"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
# Mock query results - both queries return one record each
mock_result = MagicMock()
mock_result.result_set = [["record1"]]
mock_graph.query.return_value = mock_result
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_sp_query(self, mock_falkordb):
"""Test SP query (subject and predicate specified)"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
# Mock query results with different objects
mock_result1 = MagicMock()
mock_result1.result_set = [["literal result"]]
mock_result2 = MagicMock()
mock_result2.result_set = [["http://example.com/uri_result"]]
mock_graph.query.side_effect = [mock_result1, mock_result2]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_so_query(self, mock_falkordb):
"""Test SO query (subject and object specified)"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
# Mock query results with different predicates
mock_result1 = MagicMock()
mock_result1.result_set = [["http://example.com/pred1"]]
mock_result2 = MagicMock()
mock_result2.result_set = [["http://example.com/pred2"]]
mock_graph.query.side_effect = [mock_result1, mock_result2]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
# Verify results contain different predicates
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_s_query(self, mock_falkordb):
"""Test S query (subject only)"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
# Mock query results with different predicate-object pairs
mock_result1 = MagicMock()
mock_result1.result_set = [["http://example.com/pred1", "literal1"]]
mock_result2 = MagicMock()
mock_result2.result_set = [["http://example.com/pred2", "http://example.com/uri2"]]
mock_graph.query.side_effect = [mock_result1, mock_result2]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
# Verify results contain different predicate-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "http://example.com/uri2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_po_query(self, mock_falkordb):
"""Test PO query (predicate and object specified)"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
# Mock query results with different subjects
mock_result1 = MagicMock()
mock_result1.result_set = [["http://example.com/subj1"]]
mock_result2 = MagicMock()
mock_result2.result_set = [["http://example.com/subj2"]]
mock_graph.query.side_effect = [mock_result1, mock_result2]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
# Verify results contain different subjects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_p_query(self, mock_falkordb):
"""Test P query (predicate only)"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
# Mock query results with different subject-object pairs
mock_result1 = MagicMock()
mock_result1.result_set = [["http://example.com/subj1", "literal1"]]
mock_result2 = MagicMock()
mock_result2.result_set = [["http://example.com/subj2", "http://example.com/uri2"]]
mock_graph.query.side_effect = [mock_result1, mock_result2]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
# Verify results contain different subject-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_o_query(self, mock_falkordb):
"""Test O query (object only)"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
# Mock query results with different subject-predicate pairs
mock_result1 = MagicMock()
mock_result1.result_set = [["http://example.com/subj1", "http://example.com/pred1"]]
mock_result2 = MagicMock()
mock_result2.result_set = [["http://example.com/subj2", "http://example.com/pred2"]]
mock_graph.query.side_effect = [mock_result1, mock_result2]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
# Verify results contain different subject-predicate pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_wildcard_query(self, mock_falkordb):
"""Test wildcard query (no constraints)"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
# Mock query results
mock_result1 = MagicMock()
mock_result1.result_set = [["http://example.com/s1", "http://example.com/p1", "literal1"]]
mock_result2 = MagicMock()
mock_result2.result_set = [["http://example.com/s2", "http://example.com/p2", "http://example.com/o2"]]
mock_graph.query.side_effect = [mock_result1, mock_result2]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_exception_handling(self, mock_falkordb):
"""Test exception handling during query processing"""
taskgroup_mock = MagicMock()
mock_client = MagicMock()
mock_graph = MagicMock()
mock_falkordb.from_url.return_value = mock_client
mock_client.select_graph.return_value = mock_graph
# Mock query to raise exception
mock_graph.query.side_effect = Exception("Database connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=None,
limit=100
)
# Should raise the exception
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_url')
assert args.graph_url == 'falkor://falkordb:6379'
assert hasattr(args, 'database')
assert args.database == 'falkordb'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-url', 'falkor://custom:6379',
'--database', 'querydb'
])
assert args.graph_url == 'falkor://custom:6379'
assert args.database == 'querydb'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'falkor://short:6379'])
assert args.graph_url == 'falkor://short:6379'
@patch('trustgraph.query.triples.falkordb.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.triples.falkordb.service import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nTriples query service for FalkorDB.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
)

View file

@ -0,0 +1,568 @@
"""
Tests for Memgraph triples query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
class TestMemgraphQueryProcessor:
"""Test cases for Memgraph query processor"""
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.query.triples.memgraph.service.GraphDatabase'):
return Processor(
taskgroup=MagicMock(),
id='test-memgraph-query',
graph_host='bolt://localhost:7687'
)
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert result.value == "just a literal string"
assert result.is_uri is False
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert result.value == ""
assert result.is_uri is False
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert result.value == "http"
assert result.is_uri is False
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'memgraph'
mock_graph_db.driver.assert_called_once_with(
'bolt://memgraph:7687',
auth=('memgraph', 'password')
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='bolt://custom:7687',
username='queryuser',
password='querypass',
database='customdb'
)
assert processor.db == 'customdb'
mock_graph_db.driver.assert_called_once_with(
'bolt://custom:7687',
auth=('queryuser', 'querypass')
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_spo_query(self, mock_graph_db):
"""Test SPO query (all values specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results - both queries return one record each
mock_records = [MagicMock()]
mock_driver.execute_query.return_value = (mock_records, None, None)
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_sp_query(self, mock_graph_db):
"""Test SP query (subject and predicate specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different objects
mock_record1 = MagicMock()
mock_record1.data.return_value = {"dest": "literal result"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"dest": "http://example.com/uri_result"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None), # Literal query
([mock_record2], None, None) # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_so_query(self, mock_graph_db):
"""Test SO query (subject and object specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different predicates
mock_record1 = MagicMock()
mock_record1.data.return_value = {"rel": "http://example.com/pred1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"rel": "http://example.com/pred2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None), # Literal query
([mock_record2], None, None) # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different predicates
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_s_query(self, mock_graph_db):
"""Test S query (subject only)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different predicate-object pairs
mock_record1 = MagicMock()
mock_record1.data.return_value = {"rel": "http://example.com/pred1", "dest": "literal1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"rel": "http://example.com/pred2", "dest": "http://example.com/uri2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None), # Literal query
([mock_record2], None, None) # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different predicate-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "http://example.com/uri2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_po_query(self, mock_graph_db):
"""Test PO query (predicate and object specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different subjects
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/subj1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/subj2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None), # Literal query
([mock_record2], None, None) # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different subjects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_p_query(self, mock_graph_db):
"""Test P query (predicate only)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different subject-object pairs
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/subj1", "dest": "literal1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/subj2", "dest": "http://example.com/uri2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None), # Literal query
([mock_record2], None, None) # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different subject-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_o_query(self, mock_graph_db):
"""Test O query (object only)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different subject-predicate pairs
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/subj1", "rel": "http://example.com/pred1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/subj2", "rel": "http://example.com/pred2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None), # Literal query
([mock_record2], None, None) # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different subject-predicate pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_wildcard_query(self, mock_graph_db):
"""Test wildcard query (no constraints)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/s1", "rel": "http://example.com/p1", "dest": "literal1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/s2", "rel": "http://example.com/p2", "dest": "http://example.com/o2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None), # Literal query
([mock_record2], None, None) # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_exception_handling(self, mock_graph_db):
"""Test exception handling during query processing"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock execute_query to raise exception
mock_driver.execute_query.side_effect = Exception("Database connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=None,
limit=100
)
# Should raise the exception
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://memgraph:7687'
assert hasattr(args, 'username')
assert args.username == 'memgraph'
assert hasattr(args, 'password')
assert args.password == 'password'
assert hasattr(args, 'database')
assert args.database == 'memgraph'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-host', 'bolt://custom:7687',
'--username', 'queryuser',
'--password', 'querypass',
'--database', 'querydb'
])
assert args.graph_host == 'bolt://custom:7687'
assert args.username == 'queryuser'
assert args.password == 'querypass'
assert args.database == 'querydb'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'bolt://short:7687'])
assert args.graph_host == 'bolt://short:7687'
@patch('trustgraph.query.triples.memgraph.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.triples.memgraph.service import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nTriples query service for memgraph.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
)

View file

@ -0,0 +1,338 @@
"""
Tests for Neo4j triples query service
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
class TestNeo4jQueryProcessor:
"""Test cases for Neo4j query processor"""
@pytest.fixture
def processor(self):
"""Create a processor instance for testing"""
with patch('trustgraph.query.triples.neo4j.service.GraphDatabase'):
return Processor(
taskgroup=MagicMock(),
id='test-neo4j-query',
graph_host='bolt://localhost:7687'
)
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert result.value == "just a literal string"
assert result.is_uri is False
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert result.value == ""
assert result.is_uri is False
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert result.value == "http"
assert result.is_uri is False
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
"""Test processor initialization with default parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(taskgroup=taskgroup_mock)
assert processor.db == 'neo4j'
mock_graph_db.driver.assert_called_once_with(
'bolt://neo4j:7687',
auth=('neo4j', 'password')
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
"""Test processor initialization with custom parameters"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = Processor(
taskgroup=taskgroup_mock,
graph_host='bolt://custom:7687',
username='queryuser',
password='querypass',
database='customdb'
)
assert processor.db == 'customdb'
mock_graph_db.driver.assert_called_once_with(
'bolt://custom:7687',
auth=('queryuser', 'querypass')
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_spo_query(self, mock_graph_db):
"""Test SPO query (all values specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results - both queries return one record each
mock_records = [MagicMock()]
mock_driver.execute_query.return_value = (mock_records, None, None)
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_sp_query(self, mock_graph_db):
"""Test SP query (subject and predicate specified)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results with different objects
mock_record1 = MagicMock()
mock_record1.data.return_value = {"dest": "literal result"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"dest": "http://example.com/uri_result"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None), # Literal query
([mock_record2], None, None) # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_wildcard_query(self, mock_graph_db):
"""Test wildcard query (no constraints)"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock query results
mock_record1 = MagicMock()
mock_record1.data.return_value = {"src": "http://example.com/s1", "rel": "http://example.com/p1", "dest": "literal1"}
mock_record2 = MagicMock()
mock_record2.data.return_value = {"src": "http://example.com/s2", "rel": "http://example.com/p2", "dest": "http://example.com/o2"}
mock_driver.execute_query.side_effect = [
([mock_record1], None, None), # Literal query
([mock_record2], None, None) # URI query
]
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=None,
limit=100
)
result = await processor.query_triples(query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_exception_handling(self, mock_graph_db):
"""Test exception handling during query processing"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
# Mock execute_query to raise exception
mock_driver.execute_query.side_effect = Exception("Database connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=None,
o=None,
limit=100
)
# Should raise the exception
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added
# Parse empty args to check defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://neo4j:7687'
assert hasattr(args, 'username')
assert args.username == 'neo4j'
assert hasattr(args, 'password')
assert args.password == 'password'
assert hasattr(args, 'database')
assert args.database == 'neo4j'
def test_add_args_with_custom_values(self):
"""Test add_args with custom command line values"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with custom values
args = parser.parse_args([
'--graph-host', 'bolt://custom:7687',
'--username', 'queryuser',
'--password', 'querypass',
'--database', 'querydb'
])
assert args.graph_host == 'bolt://custom:7687'
assert args.username == 'queryuser'
assert args.password == 'querypass'
assert args.database == 'querydb'
def test_add_args_short_form(self):
"""Test add_args with short form arguments"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'):
Processor.add_args(parser)
# Test parsing with short form
args = parser.parse_args(['-g', 'bolt://short:7687'])
assert args.graph_host == 'bolt://short:7687'
@patch('trustgraph.query.triples.neo4j.service.Processor.launch')
def test_run_function(self, mock_launch):
"""Test the run function calls Processor.launch with correct parameters"""
from trustgraph.query.triples.neo4j.service import run, default_ident
run()
mock_launch.assert_called_once_with(
default_ident,
"\nTriples query service for neo4j.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
)