mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 00:46:22 +02:00
Release/v1.2 (#457)
* Bump setup.py versions for 1.1 * PoC MCP server (#419) * Very initial MCP server PoC for TrustGraph * Put service on port 8000 * Add MCP container and packages to buildout * Update docs for API/CLI changes in 1.0 (#421) * Update some API basics for the 0.23/1.0 API change * Add MCP container push (#425) * Add command args to the MCP server (#426) * Host and port parameters * Added websocket arg * More docs * MCP client support (#427) - MCP client service - Tool request/response schema - API gateway support for mcp-tool - Message translation for tool request & response - Make mcp-tool using configuration service for information about where the MCP services are. * Feature/react call mcp (#428) Key Features - MCP Tool Integration: Added core MCP tool support with ToolClientSpec and ToolClient classes - API Enhancement: New mcp_tool method for flow-specific tool invocation - CLI Tooling: New tg-invoke-mcp-tool command for testing MCP integration - React Agent Enhancement: Fixed and improved multi-tool invocation capabilities - Tool Management: Enhanced CLI for tool configuration and management Changes - Added MCP tool invocation to API with flow-specific integration - Implemented ToolClientSpec and ToolClient for tool call handling - Updated agent-manager-react to invoke MCP tools with configurable types - Enhanced CLI with new commands and improved help text - Added comprehensive documentation for new CLI commands - Improved tool configuration management Testing - Added tg-invoke-mcp-tool CLI command for isolated MCP integration testing - Enhanced agent capability to invoke multiple tools simultaneously * Test suite executed from CI pipeline (#433) * Test strategy & test cases * Unit tests * Integration tests * Extending test coverage (#434) * Contract tests * Testing embeddings * Agent unit tests * Knowledge pipeline tests * Turn on contract tests * Increase storage test coverage (#435) * Fixing storage and adding tests * PR pipeline only runs quick tests * Empty 
configuration is returned as empty list, previously was not in response (#436) * Update config util to take files as well as command-line text (#437) * Updated CLI invocation and config model for tools and mcp (#438) * Updated CLI invocation and config model for tools and mcp * CLI anomalies * Tweaked the MCP tool implementation for new model * Update agent implementation to match the new model * Fix agent tools, now all tested * Fixed integration tests * Fix MCP delete tool params * Update Python deps to 1.2 * Update to enable knowledge extraction using the agent framework (#439) * Implement KG extraction agent (kg-extract-agent) * Using ReAct framework (agent-manager-react) * ReAct manager had an issue when emitting JSON, which conflicts with the ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure. * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework. 
* Migrate from setup.py to pyproject.toml (#440) * Converted setup.py to pyproject.toml * Modern package infrastructure as recommended by py docs * Install missing build deps (#441) * Install missing build deps (#442) * Implement logging strategy (#444) * Logging strategy and convert all prints() to logging invocations * Fix/startup failure (#445) * Fix logging startup problems * Fix logging startup problems (#446) * Fix logging startup problems (#447) * Fixed Mistral OCR to use current API (#448) * Fixed Mistral OCR to use current API * Added PDF decoder tests * Fix Mistral OCR ident to be standard pdf-decoder (#450) * Fix Mistral OCR ident to be standard pdf-decoder * Correct test * Schema structure refactor (#451) * Write schema refactor spec * Implemented schema refactor spec * Structure data mvp (#452) * Structured data tech spec * Architecture principles * New schemas * Updated schemas and specs * Object extractor * Add .coveragerc * New tests * Cassandra object storage * Trying to get object extraction working, issues exist * Validate librarian collection (#453) * Fix token chunker, broken API invocation (#454) * Fix token chunker, broken API invocation (#455) * Knowledge load utility CLI (#456) * Knowledge loader * More tests
This commit is contained in:
parent
c85ba197be
commit
89be656990
509 changed files with 49632 additions and 5159 deletions
148
tests/unit/test_query/conftest.py
Normal file
148
tests/unit/test_query/conftest.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""
|
||||
Shared fixtures for query tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_query_config():
    """Base configuration for query processors.

    Returns:
        dict: a fresh mapping on every call holding ``taskgroup`` (a new
        ``AsyncMock``) and the processor ``id``.
    """
    return {
        'taskgroup': AsyncMock(),
        'id': 'test-query-processor',
    }
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_query_config(base_query_config):
    """Configuration for Qdrant query processors.

    Args:
        base_query_config: mapping of settings to extend (left unmodified).

    Returns:
        dict: a new mapping containing the base settings plus the Qdrant
        ``store_uri`` and ``api_key``; the Qdrant values win on key clashes.
    """
    return base_query_config | {
        'store_uri': 'http://localhost:6333',
        'api_key': 'test-api-key',
    }
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
    """Mock Qdrant client.

    Returns:
        MagicMock: a stand-in client whose ``query_points`` call returns an
        empty list regardless of the arguments passed.
    """
    mock_client = MagicMock()
    mock_client.query_points.return_value = []
    return mock_client
|
||||
|
||||
|
||||
# Graph embeddings query fixtures
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_request():
    """Mock graph embeddings request message.

    Returns:
        MagicMock: message with one 3-element vector, ``limit`` 5 and the
        ``test_user`` / ``test_collection`` identifiers.
    """
    mock_message = MagicMock()
    mock_message.vectors = [[0.1, 0.2, 0.3]]
    mock_message.limit = 5
    mock_message.user = 'test_user'
    mock_message.collection = 'test_collection'
    return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_multiple_vectors():
    """Mock graph embeddings request with multiple vectors.

    Returns:
        MagicMock: message with two 2-element vectors, ``limit`` 3 and the
        ``multi_user`` / ``multi_collection`` identifiers.
    """
    mock_message = MagicMock()
    mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
    mock_message.limit = 3
    mock_message.user = 'multi_user'
    mock_message.collection = 'multi_collection'
    return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_query_response():
    """Mock graph embeddings query response from Qdrant.

    Returns:
        list[MagicMock]: two points whose ``payload`` dicts carry the
        ``entity`` values ``entity1`` and ``entity2``, in that order.
    """
    mock_point1 = MagicMock()
    mock_point1.payload = {'entity': 'entity1'}
    mock_point2 = MagicMock()
    mock_point2.payload = {'entity': 'entity2'}
    return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_uri_response():
    """Mock graph embeddings query response with URIs.

    Returns:
        list[MagicMock]: three points whose ``entity`` payloads are an
        ``http://`` URI, an ``https://`` URI and a plain (non-URI) string.
    """
    mock_point1 = MagicMock()
    mock_point1.payload = {'entity': 'http://example.com/entity1'}
    mock_point2 = MagicMock()
    mock_point2.payload = {'entity': 'https://secure.example.com/entity2'}
    mock_point3 = MagicMock()
    mock_point3.payload = {'entity': 'regular entity'}
    return [mock_point1, mock_point2, mock_point3]
|
||||
|
||||
|
||||
# Document embeddings query fixtures
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_request():
    """Mock document embeddings request message.

    Returns:
        MagicMock: message with one 3-element vector, ``limit`` 5 and the
        ``test_user`` / ``test_collection`` identifiers.
    """
    mock_message = MagicMock()
    mock_message.vectors = [[0.1, 0.2, 0.3]]
    mock_message.limit = 5
    mock_message.user = 'test_user'
    mock_message.collection = 'test_collection'
    return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_vectors():
    """Mock document embeddings request with multiple vectors.

    Returns:
        MagicMock: message with two 2-element vectors, ``limit`` 3 and the
        ``multi_user`` / ``multi_collection`` identifiers.
    """
    mock_message = MagicMock()
    mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
    mock_message.limit = 3
    mock_message.user = 'multi_user'
    mock_message.collection = 'multi_collection'
    return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_query_response():
    """Mock document embeddings query response from Qdrant.

    Returns:
        list[MagicMock]: two points whose ``payload`` dicts carry the ``doc``
        text of the first and second document chunk, in that order.
    """
    mock_point1 = MagicMock()
    mock_point1.payload = {'doc': 'first document chunk'}
    mock_point2 = MagicMock()
    mock_point2.payload = {'doc': 'second document chunk'}
    return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_utf8_response():
    """Mock document embeddings query response with UTF-8 content.

    Returns:
        list[MagicMock]: two points whose ``doc`` payloads contain accented
        Latin characters and Chinese text respectively.
    """
    mock_point1 = MagicMock()
    mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
    mock_point2 = MagicMock()
    mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
    return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_empty_query_response():
    """Mock empty query response.

    Returns:
        list: a new empty list on every call.
    """
    return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_large_query_response():
    """Mock large query response with many results.

    Returns:
        list[MagicMock]: ten points whose ``payload`` is
        ``{'doc': 'document chunk <i>'}`` for ``i`` in 0..9, in order.
    """
    # MagicMock keyword arguments are set as attributes on the new mock, so
    # each point gets its own payload dict.
    return [
        MagicMock(payload={'doc': f'document chunk {i}'})
        for i in range(10)
    ]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mixed_dimension_vectors():
    """Mock request with vectors of different dimensions.

    Returns:
        MagicMock: message carrying one 2-element and one 3-element vector,
        ``limit`` 5 and the ``dim_user`` / ``dim_collection`` identifiers.
    """
    mock_message = MagicMock()
    mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]]  # 2D and 3D
    mock_message.limit = 5
    mock_message.user = 'dim_user'
    mock_message.collection = 'dim_collection'
    return mock_message
|
||||
456
tests/unit/test_query/test_doc_embeddings_milvus_query.py
Normal file
456
tests/unit/test_query/test_doc_embeddings_milvus_query.py
Normal file
|
|
@ -0,0 +1,456 @@
|
|||
"""
|
||||
Tests for Milvus document embeddings query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.doc_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import DocumentEmbeddingsRequest
|
||||
|
||||
|
||||
class TestMilvusDocEmbeddingsQueryProcessor:
    """Test cases for Milvus document embeddings query processor.

    ``DocVectors`` is patched in the service module wherever a ``Processor``
    is constructed, so ``processor.vecstore`` is a ``MagicMock`` whose
    ``search`` return values each test scripts explicitly.
    """

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        with patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors') as mock_doc_vectors:
            mock_vecstore = MagicMock()
            mock_doc_vectors.return_value = mock_vecstore

            processor = Processor(
                taskgroup=MagicMock(),
                id='test-milvus-de-query',
                store_uri='http://localhost:19530'
            )

            return processor

    @pytest.fixture
    def mock_query_request(self):
        """Create a mock query request for testing"""
        # NOTE(review): no test in this class currently requests this fixture.
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=10
        )
        return query

    @patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors')
    def test_processor_initialization_with_defaults(self, mock_doc_vectors):
        """Test processor initialization with default parameters"""
        taskgroup_mock = MagicMock()
        mock_vecstore = MagicMock()
        mock_doc_vectors.return_value = mock_vecstore

        processor = Processor(taskgroup=taskgroup_mock)

        mock_doc_vectors.assert_called_once_with('http://localhost:19530')
        assert processor.vecstore == mock_vecstore

    @patch('trustgraph.query.doc_embeddings.milvus.service.DocVectors')
    def test_processor_initialization_with_custom_params(self, mock_doc_vectors):
        """Test processor initialization with custom parameters"""
        taskgroup_mock = MagicMock()
        mock_vecstore = MagicMock()
        mock_doc_vectors.return_value = mock_vecstore

        processor = Processor(
            taskgroup=taskgroup_mock,
            store_uri='http://custom-milvus:19530'
        )

        mock_doc_vectors.assert_called_once_with('http://custom-milvus:19530')
        assert processor.vecstore == mock_vecstore

    @pytest.mark.asyncio
    async def test_query_document_embeddings_single_vector(self, processor):
        """Test querying document embeddings with a single vector"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )

        # Mock search results
        mock_results = [
            {"entity": {"doc": "First document chunk"}},
            {"entity": {"doc": "Second document chunk"}},
            {"entity": {"doc": "Third document chunk"}},
        ]
        processor.vecstore.search.return_value = mock_results

        result = await processor.query_document_embeddings(query)

        # Verify search was called with correct parameters
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)

        # Verify results are document chunks
        assert len(result) == 3
        assert result[0] == "First document chunk"
        assert result[1] == "Second document chunk"
        assert result[2] == "Third document chunk"

    @pytest.mark.asyncio
    async def test_query_document_embeddings_multiple_vectors(self, processor):
        """Test querying document embeddings with multiple vectors"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=3
        )

        # Mock search results - different results for each vector
        mock_results_1 = [
            {"entity": {"doc": "Document from first vector"}},
            {"entity": {"doc": "Another doc from first vector"}},
        ]
        mock_results_2 = [
            {"entity": {"doc": "Document from second vector"}},
        ]
        processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]

        result = await processor.query_document_embeddings(query)

        # Verify search was called twice with correct parameters
        expected_calls = [
            (([0.1, 0.2, 0.3],), {"limit": 3}),
            (([0.4, 0.5, 0.6],), {"limit": 3}),
        ]
        # The count is asserted first so the zip below cannot silently skip
        # a missing or extra call.
        assert processor.vecstore.search.call_count == 2
        actual_calls = processor.vecstore.search.call_args_list
        for actual_call, (expected_args, expected_kwargs) in zip(actual_calls, expected_calls):
            assert actual_call[0] == expected_args
            assert actual_call[1] == expected_kwargs

        # Verify results from all vectors are combined
        assert len(result) == 3
        assert "Document from first vector" in result
        assert "Another doc from first vector" in result
        assert "Document from second vector" in result

    @pytest.mark.asyncio
    async def test_query_document_embeddings_with_limit(self, processor):
        """Test querying document embeddings respects limit parameter"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=2
        )

        # Mock search results - more results than limit
        mock_results = [
            {"entity": {"doc": "Document 1"}},
            {"entity": {"doc": "Document 2"}},
            {"entity": {"doc": "Document 3"}},
            {"entity": {"doc": "Document 4"}},
        ]
        processor.vecstore.search.return_value = mock_results

        result = await processor.query_document_embeddings(query)

        # Verify search was called with the specified limit
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=2)

        # Verify all results are returned (Milvus handles limit internally)
        assert len(result) == 4

    @pytest.mark.asyncio
    async def test_query_document_embeddings_empty_vectors(self, processor):
        """Test querying document embeddings with empty vectors list"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[],
            limit=5
        )

        result = await processor.query_document_embeddings(query)

        # Verify no search was called
        processor.vecstore.search.assert_not_called()

        # Verify empty results
        assert len(result) == 0

    @pytest.mark.asyncio
    async def test_query_document_embeddings_empty_search_results(self, processor):
        """Test querying document embeddings with empty search results"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )

        # Mock empty search results
        processor.vecstore.search.return_value = []

        result = await processor.query_document_embeddings(query)

        # Verify search was called
        processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5)

        # Verify empty results
        assert len(result) == 0

    @pytest.mark.asyncio
    async def test_query_document_embeddings_unicode_documents(self, processor):
        """Test querying document embeddings with Unicode document content"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )

        # Mock search results with Unicode content
        mock_results = [
            {"entity": {"doc": "Document with Unicode: éñ中文🚀"}},
            {"entity": {"doc": "Regular ASCII document"}},
            {"entity": {"doc": "Document with émojis: 😀🎉"}},
        ]
        processor.vecstore.search.return_value = mock_results

        result = await processor.query_document_embeddings(query)

        # Verify Unicode content is preserved
        assert len(result) == 3
        assert "Document with Unicode: éñ中文🚀" in result
        assert "Regular ASCII document" in result
        assert "Document with émojis: 😀🎉" in result

    @pytest.mark.asyncio
    async def test_query_document_embeddings_large_documents(self, processor):
        """Test querying document embeddings with large document content"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )

        # Mock search results with large content
        large_doc = "A" * 10000  # 10KB of content
        mock_results = [
            {"entity": {"doc": large_doc}},
            {"entity": {"doc": "Small document"}},
        ]
        processor.vecstore.search.return_value = mock_results

        result = await processor.query_document_embeddings(query)

        # Verify large content is preserved
        assert len(result) == 2
        assert large_doc in result
        assert "Small document" in result

    @pytest.mark.asyncio
    async def test_query_document_embeddings_special_characters(self, processor):
        """Test querying document embeddings with special characters in documents"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )

        # Mock search results with special characters
        mock_results = [
            {"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}},
            {"entity": {"doc": "Document with\nnewlines\tand\ttabs"}},
            {"entity": {"doc": "Document with special chars: @#$%^&*()"}},
        ]
        processor.vecstore.search.return_value = mock_results

        result = await processor.query_document_embeddings(query)

        # Verify special characters are preserved
        assert len(result) == 3
        assert "Document with \"quotes\" and 'apostrophes'" in result
        assert "Document with\nnewlines\tand\ttabs" in result
        assert "Document with special chars: @#$%^&*()" in result

    @pytest.mark.asyncio
    async def test_query_document_embeddings_zero_limit(self, processor):
        """Test querying document embeddings with zero limit"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=0
        )

        result = await processor.query_document_embeddings(query)

        # Verify no search was called (optimization for zero limit)
        processor.vecstore.search.assert_not_called()

        # Verify empty results due to zero limit
        assert len(result) == 0

    @pytest.mark.asyncio
    async def test_query_document_embeddings_negative_limit(self, processor):
        """Test querying document embeddings with negative limit"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=-1
        )

        result = await processor.query_document_embeddings(query)

        # Verify no search was called (optimization for negative limit)
        processor.vecstore.search.assert_not_called()

        # Verify empty results due to negative limit
        assert len(result) == 0

    @pytest.mark.asyncio
    async def test_query_document_embeddings_exception_handling(self, processor):
        """Test exception handling during query processing"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3]],
            limit=5
        )

        # Mock search to raise exception
        processor.vecstore.search.side_effect = Exception("Milvus connection failed")

        # Should raise the exception
        with pytest.raises(Exception, match="Milvus connection failed"):
            await processor.query_document_embeddings(query)

    @pytest.mark.asyncio
    async def test_query_document_embeddings_different_vector_dimensions(self, processor):
        """Test querying document embeddings with different vector dimensions"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[
                [0.1, 0.2],  # 2D vector
                [0.3, 0.4, 0.5, 0.6],  # 4D vector
                [0.7, 0.8, 0.9]  # 3D vector
            ],
            limit=5
        )

        # Mock search results for each vector
        mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}]
        mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}]
        mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}]
        processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]

        result = await processor.query_document_embeddings(query)

        # Verify all vectors were searched
        assert processor.vecstore.search.call_count == 3

        # Verify results from all dimensions
        assert len(result) == 3
        assert "Document from 2D vector" in result
        assert "Document from 4D vector" in result
        assert "Document from 3D vector" in result

    @pytest.mark.asyncio
    async def test_query_document_embeddings_duplicate_documents(self, processor):
        """Test querying document embeddings with duplicate documents in results"""
        query = DocumentEmbeddingsRequest(
            user='test_user',
            collection='test_collection',
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=5
        )

        # Mock search results with duplicates across vectors
        mock_results_1 = [
            {"entity": {"doc": "Document A"}},
            {"entity": {"doc": "Document B"}},
        ]
        mock_results_2 = [
            {"entity": {"doc": "Document B"}},  # Duplicate
            {"entity": {"doc": "Document C"}},
        ]
        processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]

        result = await processor.query_document_embeddings(query)

        # Note: Unlike graph embeddings, doc embeddings don't deduplicate
        # This preserves ranking and allows multiple occurrences
        assert len(result) == 4
        assert result.count("Document B") == 2  # Should appear twice
        assert "Document A" in result
        assert "Document C" in result

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        from argparse import ArgumentParser

        parser = ArgumentParser()

        # Mock the parent class add_args method
        with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)

            # Verify parent add_args was called
            mock_parent_add_args.assert_called_once()

        # Verify our specific arguments were added
        # Parse empty args to check defaults
        args = parser.parse_args([])

        assert hasattr(args, 'store_uri')
        assert args.store_uri == 'http://localhost:19530'

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        from argparse import ArgumentParser

        parser = ArgumentParser()

        with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)

        # Test parsing with custom values
        args = parser.parse_args([
            '--store-uri', 'http://custom-milvus:19530'
        ])

        assert args.store_uri == 'http://custom-milvus:19530'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        from argparse import ArgumentParser

        parser = ArgumentParser()

        with patch('trustgraph.query.doc_embeddings.milvus.service.DocumentEmbeddingsQueryService.add_args'):
            Processor.add_args(parser)

        # Test parsing with short form
        args = parser.parse_args(['-t', 'http://short-milvus:19530'])

        assert args.store_uri == 'http://short-milvus:19530'

    @patch('trustgraph.query.doc_embeddings.milvus.service.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.query.doc_embeddings.milvus.service import run, default_ident

        run()

        mock_launch.assert_called_once_with(
            default_ident,
            "\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n"
        )
|
||||
558
tests/unit/test_query/test_doc_embeddings_pinecone_query.py
Normal file
558
tests/unit/test_query/test_doc_embeddings_pinecone_query.py
Normal file
|
|
@ -0,0 +1,558 @@
|
|||
"""
|
||||
Tests for Pinecone document embeddings query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.doc_embeddings.pinecone.service import Processor
|
||||
|
||||
|
||||
class TestPineconeDocEmbeddingsQueryProcessor:
|
||||
"""Test cases for Pinecone document embeddings query processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query_message(self):
    """Create a mock query message for testing.

    Returns:
        MagicMock: message with two 3-element vectors, ``limit`` 5 and the
        ``test_user`` / ``test_collection`` identifiers.
    """
    message = MagicMock()
    message.vectors = [
        [0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6],
    ]
    message.limit = 5
    message.user = 'test_user'
    message.collection = 'test_collection'
    return message
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
    """Create a processor instance for testing.

    ``Pinecone`` is patched in the service module while the ``Processor``
    is constructed, so its constructor call yields a ``MagicMock``.
    """
    with patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone') as mock_pinecone_class:
        mock_pinecone = MagicMock()
        mock_pinecone_class.return_value = mock_pinecone

        processor = Processor(
            taskgroup=MagicMock(),
            id='test-pinecone-de-query',
            api_key='test-api-key'
        )

        return processor
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone')
|
||||
@patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'env-api-key')
|
||||
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
    """Test processor initialization with default parameters.

    ``mock_pinecone_class`` is the patched ``Pinecone`` class; the module's
    ``default_api_key`` is patched to ``'env-api-key'`` for this test.
    """
    mock_pinecone = MagicMock()
    mock_pinecone_class.return_value = mock_pinecone
    taskgroup_mock = MagicMock()

    processor = Processor(taskgroup=taskgroup_mock)

    # With no api_key argument the patched module default must be used.
    mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
    assert processor.pinecone == mock_pinecone
    assert processor.api_key == 'env-api-key'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.pinecone.service.Pinecone')
|
||||
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
    """Test processor initialization with custom parameters.

    ``mock_pinecone_class`` is the patched ``Pinecone`` class.
    """
    mock_pinecone = MagicMock()
    mock_pinecone_class.return_value = mock_pinecone
    taskgroup_mock = MagicMock()

    processor = Processor(
        taskgroup=taskgroup_mock,
        api_key='custom-api-key'
    )

    # An explicit api_key must be forwarded to the client and stored.
    mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
    assert processor.api_key == 'custom-api-key'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.pinecone.service.PineconeGRPC')
|
||||
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
    """Test processor initialization with custom URL (GRPC mode).

    ``mock_pinecone_grpc_class`` is the patched ``PineconeGRPC`` class.
    """
    mock_pinecone = MagicMock()
    mock_pinecone_grpc_class.return_value = mock_pinecone
    taskgroup_mock = MagicMock()

    processor = Processor(
        taskgroup=taskgroup_mock,
        api_key='test-api-key',
        url='https://custom-host.pinecone.io'
    )

    # The url is expected to reach the GRPC client as its ``host`` argument.
    mock_pinecone_grpc_class.assert_called_once_with(
        api_key='test-api-key',
        host='https://custom-host.pinecone.io'
    )
    assert processor.pinecone == mock_pinecone
    assert processor.url == 'https://custom-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.pinecone.service.default_api_key', 'not-specified')
|
||||
def test_processor_initialization_missing_api_key(self):
    """Test processor initialization fails with missing API key.

    The module's ``default_api_key`` is patched to ``'not-specified'`` for
    this test, and no ``api_key`` argument is supplied.
    """
    taskgroup_mock = MagicMock()

    with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
        Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_single_vector(self, processor):
    """Test querying document embeddings with a single vector"""
    message = MagicMock()
    message.vectors = [[0.1, 0.2, 0.3]]
    message.limit = 3
    message.user = 'test_user'
    message.collection = 'test_collection'

    # Mock index and query results
    mock_index = MagicMock()
    processor.pinecone.Index.return_value = mock_index

    mock_results = MagicMock()
    mock_results.matches = [
        MagicMock(metadata={'doc': 'First document chunk'}),
        MagicMock(metadata={'doc': 'Second document chunk'}),
        MagicMock(metadata={'doc': 'Third document chunk'}),
    ]
    mock_index.query.return_value = mock_results

    chunks = await processor.query_document_embeddings(message)

    # Verify index was accessed correctly
    # NOTE(review): the trailing "3" matches both the vector dimension and
    # the limit in this test - presumably the dimension; confirm against
    # the service and consider using distinct values.
    expected_index_name = "d-test_user-test_collection-3"
    processor.pinecone.Index.assert_called_once_with(expected_index_name)

    # Verify query parameters
    mock_index.query.assert_called_once_with(
        vector=[0.1, 0.2, 0.3],
        top_k=3,
        include_values=False,
        include_metadata=True
    )

    # Verify results
    assert len(chunks) == 3
    assert chunks[0] == 'First document chunk'
    assert chunks[1] == 'Second document chunk'
    assert chunks[2] == 'Third document chunk'
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_multiple_vectors(self, processor, mock_query_message):
    """Chunks from two consecutive index queries are all present in the result."""
    index = MagicMock()
    processor.pinecone.Index.return_value = index

    # Two successive query() calls each yield a different pair of matches.
    first_batch = MagicMock(matches=[
        MagicMock(metadata={'doc': 'Document chunk 1'}),
        MagicMock(metadata={'doc': 'Document chunk 2'}),
    ])
    second_batch = MagicMock(matches=[
        MagicMock(metadata={'doc': 'Document chunk 3'}),
        MagicMock(metadata={'doc': 'Document chunk 4'}),
    ])
    index.query.side_effect = [first_batch, second_batch]

    chunks = await processor.query_document_embeddings(mock_query_message)

    # Exactly two queries must have been issued.
    assert index.query.call_count == 2

    # Every chunk from both batches must be present (order is not checked).
    assert len(chunks) == 4
    for number in range(1, 5):
        assert f'Document chunk {number}' in chunks
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_limit_handling(self, processor):
    """The message limit is forwarded as top_k; returned matches are not trimmed."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=2,
        user='test_user',
        collection='test_collection',
    )

    # The mocked index hands back ten matches although only two were requested.
    index = MagicMock()
    index.query.return_value = MagicMock(
        matches=[
            MagicMock(metadata={'doc': f'Document chunk {i}'})
            for i in range(10)
        ]
    )
    processor.pinecone.Index.return_value = index

    chunks = await processor.query_document_embeddings(request)

    # The limit must reach the query as its top_k keyword argument.
    index.query.assert_called_once()
    _, query_kwargs = index.query.call_args
    assert query_kwargs['top_k'] == 2

    # Results should contain all returned chunks (limit is applied by Pinecone)
    assert len(chunks) == 10
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_zero_limit(self, processor):
    """A limit of zero yields an empty list and issues no index query."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=0,
        user='test_user',
        collection='test_collection',
    )

    index = MagicMock()
    processor.pinecone.Index.return_value = index

    chunks = await processor.query_document_embeddings(request)

    # The index must never be queried, and the result must be empty.
    index.query.assert_not_called()
    assert chunks == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_negative_limit(self, processor):
    """A negative limit yields an empty list and issues no index query."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=-1,
        user='test_user',
        collection='test_collection',
    )

    index = MagicMock()
    processor.pinecone.Index.return_value = index

    chunks = await processor.query_document_embeddings(request)

    # The index must never be queried, and the result must be empty.
    index.query.assert_not_called()
    assert chunks == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
    """Vectors of different lengths are queried against different indexes."""
    request = MagicMock(
        vectors=[
            [0.1, 0.2],            # 2D vector
            [0.3, 0.4, 0.5, 0.6],  # 4D vector
        ],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    # One mock index per dimension, each returning a single distinctive match.
    index_2d = MagicMock()
    index_2d.query.return_value = MagicMock(
        matches=[MagicMock(metadata={'doc': 'Document from 2D index'})]
    )
    index_4d = MagicMock()
    index_4d.query.return_value = MagicMock(
        matches=[MagicMock(metadata={'doc': 'Document from 4D index'})]
    )

    indexes_by_suffix = (("-2", index_2d), ("-4", index_4d))

    def pick_index(name):
        # Select the mock by the suffix of the requested index name;
        # any other name yields None.
        for suffix, candidate in indexes_by_suffix:
            if name.endswith(suffix):
                return candidate
        return None

    processor.pinecone.Index.side_effect = pick_index

    chunks = await processor.query_document_embeddings(request)

    # Two index lookups, and each index queried exactly once.
    assert processor.pinecone.Index.call_count == 2
    index_2d.query.assert_called_once()
    index_4d.query.assert_called_once()

    # Chunks from both indexes must be present in the combined result.
    for text in ('Document from 2D index', 'Document from 4D index'):
        assert text in chunks
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_empty_vectors_list(self, processor):
    """With no query vectors, no index is opened or queried and the result is empty."""
    request = MagicMock(
        vectors=[],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    index = MagicMock()
    processor.pinecone.Index.return_value = index

    chunks = await processor.query_document_embeddings(request)

    # Neither the index factory nor the index itself may be touched.
    processor.pinecone.Index.assert_not_called()
    index.query.assert_not_called()
    assert chunks == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_no_results(self, processor):
    """An index query that returns no matches produces an empty chunk list."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    # The index answers the query with an empty match list.
    index = MagicMock()
    index.query.return_value = MagicMock(matches=[])
    processor.pinecone.Index.return_value = index

    chunks = await processor.query_document_embeddings(request)

    assert chunks == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_unicode_content(self, processor):
    """Chunk text containing non-ASCII characters is returned unchanged."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=2,
        user='test_user',
        collection='test_collection',
    )

    # One match with accented, CJK and emoji characters, one plain ASCII.
    texts = (
        'Document with Unicode: éñ中文🚀',
        'Regular ASCII document',
    )
    index = MagicMock()
    index.query.return_value = MagicMock(
        matches=[MagicMock(metadata={'doc': text}) for text in texts]
    )
    processor.pinecone.Index.return_value = index

    chunks = await processor.query_document_embeddings(request)

    # Both strings must come back exactly as stored.
    assert len(chunks) == 2
    for text in texts:
        assert text in chunks
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_large_content(self, processor):
    """A 10,000-character chunk is returned intact."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=1,
        user='test_user',
        collection='test_collection',
    )

    big_text = "A" * 10000  # 10KB of content

    index = MagicMock()
    index.query.return_value = MagicMock(
        matches=[MagicMock(metadata={'doc': big_text})]
    )
    processor.pinecone.Index.return_value = index

    chunks = await processor.query_document_embeddings(request)

    # The single chunk must equal the full original text.
    assert len(chunks) == 1
    assert chunks[0] == big_text
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_mixed_content_types(self, processor):
    """Short, long, symbol-laden, whitespace-padded and empty chunks all come back."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    texts = [
        'Short text',
        'A' * 1000,  # Long text
        'Text with numbers: 123 and symbols: @#$',
        ' Whitespace text ',
        '',  # Empty string
    ]
    index = MagicMock()
    index.query.return_value = MagicMock(
        matches=[MagicMock(metadata={'doc': text}) for text in texts]
    )
    processor.pinecone.Index.return_value = index

    chunks = await processor.query_document_embeddings(request)

    # Every stored text, including the empty string, must be present.
    assert len(chunks) == 5
    for text in texts:
        assert text in chunks
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_exception_handling(self, processor):
    """An exception raised by index.query propagates to the caller."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    # The index opens fine, but querying it fails.
    failing_index = MagicMock()
    failing_index.query.side_effect = Exception("Query failed")
    processor.pinecone.Index.return_value = failing_index

    with pytest.raises(Exception, match="Query failed"):
        await processor.query_document_embeddings(request)
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_index_access_failure(self, processor):
    """An exception raised while opening the index propagates to the caller."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    # Opening any index fails before a query can be issued.
    processor.pinecone.Index.side_effect = Exception("Index access failed")

    with pytest.raises(Exception, match="Index access failed"):
        await processor.query_document_embeddings(request)
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_document_embeddings_vector_accumulation(self, processor):
    """Chunks from three per-vector queries (2 + 1 + 3 matches) all reach the result."""
    request = MagicMock(
        vectors=[
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9],
        ],
        limit=2,
        user='test_user',
        collection='test_collection',
    )

    index = MagicMock()
    processor.pinecone.Index.return_value = index

    # One group of chunk texts per query() call, in call order.
    docs_per_query = [
        ['Doc from vector 1.1', 'Doc from vector 1.2'],
        ['Doc from vector 2.1'],
        ['Doc from vector 3.1', 'Doc from vector 3.2', 'Doc from vector 3.3'],
    ]
    index.query.side_effect = [
        MagicMock(matches=[MagicMock(metadata={'doc': text}) for text in group])
        for group in docs_per_query
    ]

    chunks = await processor.query_document_embeddings(request)

    # One query per vector.
    assert index.query.call_count == 3

    # Nothing may be dropped when the per-query results are combined.
    assert len(chunks) == 6
    for group in docs_per_query:
        for text in group:
            assert text in chunks
|
||||
|
||||
def test_add_args_method(self):
    """add_args calls the parent's add_args and registers api_key and url options."""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parent_add_args = (
        'trustgraph.query.doc_embeddings.pinecone.service.'
        'DocumentEmbeddingsQueryService.add_args'
    )
    parser = ArgumentParser()

    # The parent class's add_args is replaced by a mock so the call can be observed.
    with patch(parent_add_args) as mock_parent_add_args:
        Processor.add_args(parser)

    mock_parent_add_args.assert_called_once()

    # Parsing an empty command line exposes the registered defaults.
    args = parser.parse_args([])

    assert hasattr(args, 'api_key')
    assert args.api_key == 'not-specified'  # Default value when no env var
    assert hasattr(args, 'url')
    assert args.url is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
    """Long-form --api-key and --url options are parsed into api_key and url."""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parent_add_args = (
        'trustgraph.query.doc_embeddings.pinecone.service.'
        'DocumentEmbeddingsQueryService.add_args'
    )
    parser = ArgumentParser()

    # The parent class's add_args is mocked out for the duration of the call.
    with patch(parent_add_args):
        Processor.add_args(parser)

    argv = [
        '--api-key', 'custom-api-key',
        '--url', 'https://custom-host.pinecone.io',
    ]
    args = parser.parse_args(argv)

    assert args.api_key == 'custom-api-key'
    assert args.url == 'https://custom-host.pinecone.io'
|
||||
|
||||
def test_add_args_short_form(self):
    """Short-form -a and -u options are parsed into api_key and url."""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parent_add_args = (
        'trustgraph.query.doc_embeddings.pinecone.service.'
        'DocumentEmbeddingsQueryService.add_args'
    )
    parser = ArgumentParser()

    # The parent class's add_args is mocked out for the duration of the call.
    with patch(parent_add_args):
        Processor.add_args(parser)

    argv = [
        '-a', 'short-api-key',
        '-u', 'https://short-host.pinecone.io',
    ]
    args = parser.parse_args(argv)

    assert args.api_key == 'short-api-key'
    assert args.url == 'https://short-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.pinecone.service.Processor.launch')
def test_run_function(self, mock_launch):
    """run() calls Processor.launch once with default_ident and the service description."""
    from trustgraph.query.doc_embeddings.pinecone.service import run, default_ident

    # The description text that must be passed as launch's second argument.
    expected_description = (
        "\nDocument embeddings query service. Input is vector, output is an array\n"
        "of chunks. Pinecone implementation.\n"
    )

    run()

    mock_launch.assert_called_once_with(default_ident, expected_description)
|
||||
542
tests/unit/test_query/test_doc_embeddings_qdrant_query.py
Normal file
542
tests/unit/test_query/test_doc_embeddings_qdrant_query.py
Normal file
|
|
@ -0,0 +1,542 @@
|
|||
"""
|
||||
Unit tests for trustgraph.query.doc_embeddings.qdrant.service
|
||||
Testing document embeddings query functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.query.doc_embeddings.qdrant.service import Processor
|
||||
|
||||
|
||||
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant document embeddings query functionality"""
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
    """Explicit store_uri and api_key reach QdrantClient; the client is kept as .qdrant."""
    # Arrange: neutralise the base-class constructor and stub the client class.
    mock_base_init.return_value = None
    client = MagicMock()
    mock_qdrant_client.return_value = client

    # Act
    processor = Processor(
        store_uri='http://localhost:6333',
        api_key='test-api-key',
        taskgroup=AsyncMock(),
        id='test-doc-query-processor',
    )

    # Assert: base initialisation ran exactly once.
    mock_base_init.assert_called_once()

    # The client must be built from the supplied URI and key.
    mock_qdrant_client.assert_called_once_with(
        url='http://localhost:6333', api_key='test-api-key'
    )

    # The created client must be exposed on the processor.
    assert hasattr(processor, 'qdrant')
    assert processor.qdrant == client
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
    """Without store_uri/api_key the client is built with the default URI and no key."""
    # Arrange
    mock_base_init.return_value = None
    mock_qdrant_client.return_value = MagicMock()

    # Act: only taskgroup and id are given, so store_uri and api_key
    # must fall back to their defaults.
    Processor(taskgroup=AsyncMock(), id='test-doc-query-processor')

    # Assert: default URI and a None API key were used.
    mock_qdrant_client.assert_called_once_with(
        url='http://localhost:6333', api_key=None
    )
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
    """One query vector produces one query_points call and its chunks, in order."""
    # Arrange
    mock_base_init.return_value = None
    client = MagicMock()
    mock_qdrant_client.return_value = client

    # Each returned point carries its chunk text under the 'doc' payload key.
    expected_docs = ['first document chunk', 'second document chunk']
    client.query_points.return_value = MagicMock(
        points=[MagicMock(payload={'doc': text}) for text in expected_docs]
    )

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    # Act
    result = await processor.query_document_embeddings(request)

    # Assert: the query went to the expected collection with these parameters.
    client.query_points.assert_called_once_with(
        collection_name='d_test_user_test_collection_3',
        query=[0.1, 0.2, 0.3],
        limit=5,  # Direct limit, no multiplication
        with_payload=True,
    )

    # Results are plain strings (document chunks), in point order.
    assert len(result) == 2
    for position, text in enumerate(expected_docs):
        assert isinstance(result[position], str)
    for position, text in enumerate(expected_docs):
        assert result[position] == text
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
    """Two same-length vectors give two queries on one collection; results are combined."""
    # Arrange
    mock_base_init.return_value = None
    client = MagicMock()
    mock_qdrant_client.return_value = client

    # The first query returns one point, the second returns two.
    first_response = MagicMock(points=[
        MagicMock(payload={'doc': 'document from vector 1'}),
    ])
    second_response = MagicMock(points=[
        MagicMock(payload={'doc': 'document from vector 2'}),
        MagicMock(payload={'doc': 'another document from vector 2'}),
    ])
    client.query_points.side_effect = [first_response, second_response]

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2], [0.3, 0.4]],
        limit=3,
        user='multi_user',
        collection='multi_collection',
    )

    # Act
    result = await processor.query_document_embeddings(request)

    # Assert: one query per vector.
    assert client.query_points.call_count == 2

    # Both queries target the same collection, each with its own vector.
    target_collection = 'd_multi_user_multi_collection_2'
    first_call, second_call = client.query_points.call_args_list
    assert first_call.kwargs['collection_name'] == target_collection
    assert second_call.kwargs['collection_name'] == target_collection
    assert first_call.kwargs['query'] == [0.1, 0.2]
    assert second_call.kwargs['query'] == [0.3, 0.4]

    # All three documents must be present in the combined result.
    assert len(result) == 3
    for text in (
        'document from vector 1',
        'document from vector 2',
        'another document from vector 2',
    ):
        assert text in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
    """The message limit is forwarded unchanged; returned points are not trimmed."""
    # Arrange
    mock_base_init.return_value = None
    client = MagicMock()
    mock_qdrant_client.return_value = client

    # The mocked client returns ten points regardless of the requested limit.
    client.query_points.return_value = MagicMock(
        points=[
            MagicMock(payload={'doc': f'document chunk {i}'})
            for i in range(10)
        ]
    )

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=3,
        user='limit_user',
        collection='limit_collection',
    )

    # Act
    result = await processor.query_document_embeddings(request)

    # Assert: the query carries the exact limit (no multiplication).
    client.query_points.assert_called_once()
    assert client.query_points.call_args.kwargs['limit'] == 3  # Direct limit

    # All points returned by the mock come through (limit applied by Qdrant).
    assert len(result) == 10
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
    """A query response with no points produces an empty result list."""
    # Arrange
    mock_base_init.return_value = None
    client = MagicMock()
    mock_qdrant_client.return_value = client

    # The client answers with an empty point list.
    client.query_points.return_value = MagicMock(points=[])

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2]],
        limit=5,
        user='empty_user',
        collection='empty_collection',
    )

    # Act
    result = await processor.query_document_embeddings(request)

    # Assert
    assert result == []
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
    """Vectors of length 2 and 3 are queried against differently suffixed collections."""
    # Arrange
    mock_base_init.return_value = None
    client = MagicMock()
    mock_qdrant_client.return_value = client

    # One response per query, each with a single distinctive document.
    client.query_points.side_effect = [
        MagicMock(points=[MagicMock(payload={'doc': 'document from 2D vector'})]),
        MagicMock(points=[MagicMock(payload={'doc': 'document from 3D vector'})]),
    ]

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2], [0.3, 0.4, 0.5]],  # 2D and 3D
        limit=5,
        user='dim_user',
        collection='dim_collection',
    )

    # Act
    result = await processor.query_document_embeddings(request)

    # Assert: one query per vector.
    assert client.query_points.call_count == 2
    call_for_2d, call_for_3d = client.query_points.call_args_list

    # First call should use 2D collection
    assert call_for_2d.kwargs['collection_name'] == 'd_dim_user_dim_collection_2'
    assert call_for_2d.kwargs['query'] == [0.1, 0.2]

    # Second call should use 3D collection
    assert call_for_3d.kwargs['collection_name'] == 'd_dim_user_dim_collection_3'
    assert call_for_3d.kwargs['query'] == [0.3, 0.4, 0.5]

    # Documents from both collections must be present.
    assert len(result) == 2
    for text in ('document from 2D vector', 'document from 3D vector'):
        assert text in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_utf8_encoding(self, mock_base_init, mock_qdrant_client):
    """Document text containing non-ASCII characters is returned unchanged."""
    # Arrange
    mock_base_init.return_value = None
    client = MagicMock()
    mock_qdrant_client.return_value = client

    # One document with accented Latin text, one with Chinese text.
    texts = (
        'Document with UTF-8: café, naïve, résumé',
        'Chinese text: 你好世界',
    )
    client.query_points.return_value = MagicMock(
        points=[MagicMock(payload={'doc': text}) for text in texts]
    )

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2]],
        limit=5,
        user='utf8_user',
        collection='utf8_collection',
    )

    # Act
    result = await processor.query_document_embeddings(request)

    # Assert: both strings come back exactly as stored.
    assert len(result) == 2
    for text in texts:
        assert text in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
    """An exception raised by query_points propagates to the caller."""
    # Arrange
    mock_base_init.return_value = None
    client = MagicMock()
    mock_qdrant_client.return_value = client

    # Every query attempt fails with a connection error.
    client.query_points.side_effect = Exception("Qdrant connection failed")

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2]],
        limit=5,
        user='error_user',
        collection='error_collection',
    )

    # Act & Assert
    with pytest.raises(Exception, match="Qdrant connection failed"):
        await processor.query_document_embeddings(request)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
    """A zero limit is still forwarded to query_points; returned points pass through."""
    # Arrange
    mock_base_init.return_value = None
    client = MagicMock()
    mock_qdrant_client.return_value = client

    # The mocked client returns one point even though the limit is zero.
    client.query_points.return_value = MagicMock(
        points=[MagicMock(payload={'doc': 'document chunk'})]
    )

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2]],
        limit=0,
        user='zero_user',
        collection='zero_collection',
    )

    # Act
    result = await processor.query_document_embeddings(request)

    # Assert: should still query (with limit 0).
    client.query_points.assert_called_once()
    assert client.query_points.call_args.kwargs['limit'] == 0

    # Result should contain all returned documents
    assert len(result) == 1
    assert result[0] == 'document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_large_limit(self, mock_base_init, mock_qdrant_client):
    """A limit larger than the number of hits is forwarded as-is and every hit comes back."""
    # Arrange
    mock_base_init.return_value = None
    qdrant = MagicMock()
    mock_qdrant_client.return_value = qdrant

    # The stubbed store holds only two points, far fewer than the limit.
    hits = [
        MagicMock(payload={'doc': 'document 1'}),
        MagicMock(payload={'doc': 'document 2'}),
    ]
    qdrant.query_points.return_value = MagicMock(points=hits)

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2]],
        limit=1000,  # Large limit
        user='large_user',
        collection='large_collection',
    )

    # Act
    result = await processor.query_document_embeddings(request)

    # Assert: the full limit reaches the store in a single query.
    qdrant.query_points.assert_called_once()
    assert qdrant.query_points.call_args.kwargs['limit'] == 1000

    # All available documents are returned.
    assert len(result) == 2
    for text in ('document 1', 'document 2'):
        assert text in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_missing_payload(self, mock_base_init, mock_qdrant_client):
    """A point whose payload lacks the 'doc' key makes the query raise KeyError."""
    # Arrange
    mock_base_init.return_value = None
    qdrant = MagicMock()
    mock_qdrant_client.return_value = qdrant

    points = [
        MagicMock(payload={'doc': 'valid document'}),
        MagicMock(payload={}),                        # no 'doc' key at all
        MagicMock(payload={'other_key': 'invalid'}),  # unrelated key only
    ]
    qdrant.query_points.return_value = MagicMock(points=points)

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    request = MagicMock(
        vectors=[[0.1, 0.2]],
        limit=5,
        user='payload_user',
        collection='payload_collection',
    )

    # Act & Assert: reading payload['doc'] on an incomplete point must fail.
    with pytest.raises(KeyError):
        await processor.query_document_embeddings(request)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
    """Processor.add_args forwards the parser to the parent and registers its own options."""
    # Arrange
    mock_base_init.return_value = None
    mock_qdrant_client.return_value = MagicMock()
    parser = MagicMock()

    # Act
    with patch('trustgraph.base.DocumentEmbeddingsQueryService.add_args') as parent_add_args:
        Processor.add_args(parser)

    # Assert: the parent hook received the very same parser object.
    parent_add_args.assert_called_once_with(parser)

    # At least store-uri and api-key are expected to be registered.
    assert parser.add_argument.call_count >= 2
|
||||
|
||||
|
||||
# Allow this test module to be executed directly as a script.
if __name__ == '__main__':
    pytest.main([__file__])
|
||||
484
tests/unit/test_query/test_graph_embeddings_milvus_query.py
Normal file
484
tests/unit/test_query/test_graph_embeddings_milvus_query.py
Normal file
|
|
@ -0,0 +1,484 @@
|
|||
"""
|
||||
Tests for Milvus graph embeddings query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.graph_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import Value, GraphEmbeddingsRequest
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||
"""Test cases for Milvus graph embeddings query processor"""
|
||||
|
||||
@pytest.fixture
def processor(self):
    """Build a Processor whose EntityVectors store is replaced by a MagicMock."""
    target = 'trustgraph.query.graph_embeddings.milvus.service.EntityVectors'
    with patch(target) as entity_vectors_cls:
        entity_vectors_cls.return_value = MagicMock()
        # Construct inside the patch so the constructor sees the mocked class.
        return Processor(
            taskgroup=MagicMock(),
            id='test-milvus-ge-query',
            store_uri='http://localhost:19530',
        )
|
||||
|
||||
@pytest.fixture
def mock_query_request(self):
    """Provide a two-vector GraphEmbeddingsRequest with a limit of 10."""
    return GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        limit=10,
    )
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors')
def test_processor_initialization_with_defaults(self, mock_entity_vectors):
    """Without a store_uri the store is opened on http://localhost:19530."""
    vecstore = MagicMock()
    mock_entity_vectors.return_value = vecstore

    processor = Processor(taskgroup=MagicMock())

    mock_entity_vectors.assert_called_once_with('http://localhost:19530')
    assert processor.vecstore == vecstore
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.milvus.service.EntityVectors')
def test_processor_initialization_with_custom_params(self, mock_entity_vectors):
    """An explicit store_uri is the one handed to EntityVectors."""
    vecstore = MagicMock()
    mock_entity_vectors.return_value = vecstore

    processor = Processor(
        taskgroup=MagicMock(),
        store_uri='http://custom-milvus:19530',
    )

    mock_entity_vectors.assert_called_once_with('http://custom-milvus:19530')
    assert processor.vecstore == vecstore
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
    """An http:// string is wrapped in a Value flagged as a URI."""
    text = "http://example.com/resource"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == "http://example.com/resource"
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
    """An https:// string is wrapped in a Value flagged as a URI."""
    text = "https://example.com/resource"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == "https://example.com/resource"
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_with_literal(self, processor):
    """Plain text is wrapped in a Value that is not flagged as a URI."""
    text = "just a literal string"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == "just a literal string"
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
    """The empty string yields an empty, non-URI Value."""
    text = ""

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == ""
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
    """A bare scheme name without '://' is treated as a literal, not a URI."""
    text = "http"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == "http"
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
    """An ftp:// string is not recognised as a URI and stays a literal."""
    text = "ftp://example.com/file"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == "ftp://example.com/file"
    assert value.is_uri is False
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_single_vector(self, processor):
    """One query vector produces one search call and one Value per hit."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
    )

    # Two URI-like hits followed by a plain literal.
    processor.vecstore.search.return_value = [
        {"entity": {"entity": "http://example.com/entity1"}},
        {"entity": {"entity": "http://example.com/entity2"}},
        {"entity": {"entity": "literal entity"}},
    ]

    result = await processor.query_graph_embeddings(request)

    # The store is asked for twice the requested limit.
    processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)

    # Each hit becomes a Value, in order, with the URI flag set accordingly.
    expected = [
        ("http://example.com/entity1", True),
        ("http://example.com/entity2", True),
        ("literal entity", False),
    ]
    assert len(result) == 3
    for value, (text, flag) in zip(result, expected):
        assert isinstance(value, Value)
        assert value.value == text
        assert value.is_uri is flag
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor):
    """Every vector is searched and the merged hits are de-duplicated."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        limit=3,
    )

    # A different answer per vector; entity2 shows up in both.
    first_hits = [
        {"entity": {"entity": "http://example.com/entity1"}},
        {"entity": {"entity": "http://example.com/entity2"}},
    ]
    second_hits = [
        {"entity": {"entity": "http://example.com/entity2"}},  # Duplicate
        {"entity": {"entity": "http://example.com/entity3"}},
    ]
    search = processor.vecstore.search
    search.side_effect = [first_hits, second_hits]

    result = await processor.query_graph_embeddings(request)

    # Two searches, one per vector, each asking for twice the limit.
    expected_calls = [
        (([0.1, 0.2, 0.3],), {"limit": 6}),
        (([0.4, 0.5, 0.6],), {"limit": 6}),
    ]
    assert search.call_count == 2
    for recorded, (want_args, want_kwargs) in zip(search.call_args_list, expected_calls):
        assert recorded[0] == want_args
        assert recorded[1] == want_kwargs

    # Three distinct entities survive de-duplication.
    assert len(result) == 3
    entity_values = [item.value for item in result]
    for uri in (
        "http://example.com/entity1",
        "http://example.com/entity2",
        "http://example.com/entity3",
    ):
        assert uri in entity_values
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_with_limit(self, processor):
    """No more than `limit` entities are returned even when the store offers more."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3]],
        limit=2,
    )

    # Four hits from the store, twice as many as requested.
    processor.vecstore.search.return_value = [
        {"entity": {"entity": "http://example.com/entity1"}},
        {"entity": {"entity": "http://example.com/entity2"}},
        {"entity": {"entity": "http://example.com/entity3"}},
        {"entity": {"entity": "http://example.com/entity4"}},
    ]

    result = await processor.query_graph_embeddings(request)

    # The search is issued with 2*limit to leave room for de-duplication.
    processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)

    # The answer is trimmed back to the requested limit.
    assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication(self, processor):
    """Entities returned for more than one vector appear only once in the result."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        limit=5,
    )

    first_hits = [
        {"entity": {"entity": "http://example.com/entity1"}},
        {"entity": {"entity": "http://example.com/entity2"}},
    ]
    second_hits = [
        {"entity": {"entity": "http://example.com/entity2"}},  # Duplicate
        {"entity": {"entity": "http://example.com/entity1"}},  # Duplicate
        {"entity": {"entity": "http://example.com/entity3"}},  # New
    ]
    processor.vecstore.search.side_effect = [first_hits, second_hits]

    result = await processor.query_graph_embeddings(request)

    # Five raw hits collapse to three distinct entities.
    assert len(result) == 3
    entity_values = [item.value for item in result]
    assert len(set(entity_values)) == 3  # All unique
    for uri in (
        "http://example.com/entity1",
        "http://example.com/entity2",
        "http://example.com/entity3",
    ):
        assert uri in entity_values
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
    """Once the first vector fills the limit, the remaining vectors are not searched."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        limit=2,
    )

    # The first search alone returns more entities than the limit.
    processor.vecstore.search.return_value = [
        {"entity": {"entity": "http://example.com/entity1"}},
        {"entity": {"entity": "http://example.com/entity2"}},
        {"entity": {"entity": "http://example.com/entity3"}},
    ]

    result = await processor.query_graph_embeddings(request)

    # Only the first vector reached the store.
    processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4)

    # The result is capped at the limit.
    assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_empty_vectors(self, processor):
    """With no query vectors nothing is searched and nothing is returned."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[],
        limit=5,
    )

    result = await processor.query_graph_embeddings(request)

    # The store is never consulted.
    processor.vecstore.search.assert_not_called()

    # The answer is empty.
    assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_empty_search_results(self, processor):
    """A search that finds nothing yields an empty result."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
    )

    # The store has no matches for this vector.
    processor.vecstore.search.return_value = []

    result = await processor.query_graph_embeddings(request)

    # The search still happened, with twice the requested limit.
    processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10)

    # Nothing to report.
    assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
    """URI and literal hits in one response are each typed correctly."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
    )

    # Alternating URI and literal entities.
    processor.vecstore.search.return_value = [
        {"entity": {"entity": "http://example.com/uri_entity"}},
        {"entity": {"entity": "literal entity text"}},
        {"entity": {"entity": "https://example.com/another_uri"}},
        {"entity": {"entity": "another literal"}},
    ]

    result = await processor.query_graph_embeddings(request)

    assert len(result) == 4

    # Split the answer by the URI flag and check each half.
    uri_values = [item.value for item in result if item.is_uri]
    assert len(uri_values) == 2
    assert "http://example.com/uri_entity" in uri_values
    assert "https://example.com/another_uri" in uri_values

    literal_values = [item.value for item in result if not item.is_uri]
    assert len(literal_values) == 2
    assert "literal entity text" in literal_values
    assert "another literal" in literal_values
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_exception_handling(self, processor):
    """An error raised by the vector store propagates out of the query."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
    )

    # Make the store fail on the first search.
    failure = Exception("Milvus connection failed")
    processor.vecstore.search.side_effect = failure

    # The failure must surface to the caller with its message intact.
    with pytest.raises(Exception, match="Milvus connection failed"):
        await processor.query_graph_embeddings(request)
|
||||
|
||||
def test_add_args_method(self):
    """add_args calls the parent hook and registers store_uri with its default."""
    from argparse import ArgumentParser

    cli = ArgumentParser()
    parent_hook = 'trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'

    # Replace the parent's add_args so only this class's options are added.
    with patch(parent_hook) as mock_parent_add_args:
        Processor.add_args(cli)

    # The parent hook was invoked exactly once.
    mock_parent_add_args.assert_called_once()

    # Parsing an empty command line exposes the default value.
    parsed = cli.parse_args([])
    assert hasattr(parsed, 'store_uri')
    assert parsed.store_uri == 'http://localhost:19530'
|
||||
|
||||
def test_add_args_with_custom_values(self):
    """The --store-uri option overrides the default store URI."""
    from argparse import ArgumentParser

    cli = ArgumentParser()
    parent_hook = 'trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'

    with patch(parent_hook):
        Processor.add_args(cli)

    # Supply the long-form option on the command line.
    argv = ['--store-uri', 'http://custom-milvus:19530']
    parsed = cli.parse_args(argv)

    assert parsed.store_uri == 'http://custom-milvus:19530'
|
||||
|
||||
def test_add_args_short_form(self):
    """The -t short option sets store_uri."""
    from argparse import ArgumentParser

    cli = ArgumentParser()
    parent_hook = 'trustgraph.query.graph_embeddings.milvus.service.GraphEmbeddingsQueryService.add_args'

    with patch(parent_hook):
        Processor.add_args(cli)

    # Supply the short-form option on the command line.
    argv = ['-t', 'http://short-milvus:19530']
    parsed = cli.parse_args(argv)

    assert parsed.store_uri == 'http://short-milvus:19530'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.milvus.service.Processor.launch')
def test_run_function(self, mock_launch):
    """run() passes default_ident and the expected description text to Processor.launch."""
    from trustgraph.query.graph_embeddings.milvus.service import run, default_ident

    expected_description = (
        "\nGraph embeddings query service. Input is vector, output is list of\nentities\n"
    )

    run()

    mock_launch.assert_called_once_with(default_ident, expected_description)
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_zero_limit(self, processor):
    """A limit of zero skips the search entirely and returns nothing."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[[0.1, 0.2, 0.3]],
        limit=0,
    )

    result = await processor.query_graph_embeddings(request)

    # The store is never consulted when nothing was asked for.
    processor.vecstore.search.assert_not_called()

    # The answer is empty.
    assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
    """Vectors of differing lengths are each searched and their hits combined."""
    request = GraphEmbeddingsRequest(
        user='test_user',
        collection='test_collection',
        vectors=[
            [0.1, 0.2],            # 2D vector
            [0.3, 0.4, 0.5, 0.6],  # 4D vector
            [0.7, 0.8, 0.9],       # 3D vector
        ],
        limit=5,
    )

    # One distinct entity per vector, returned in query order.
    search = processor.vecstore.search
    search.side_effect = [
        [{"entity": {"entity": "entity_2d"}}],
        [{"entity": {"entity": "entity_4d"}}],
        [{"entity": {"entity": "entity_3d"}}],
    ]

    result = await processor.query_graph_embeddings(request)

    # Every vector reached the store.
    assert search.call_count == 3

    # All three entities are present in the answer.
    assert len(result) == 3
    entity_values = [item.value for item in result]
    for name in ("entity_2d", "entity_4d", "entity_3d"):
        assert name in entity_values
|
||||
507
tests/unit/test_query/test_graph_embeddings_pinecone_query.py
Normal file
507
tests/unit/test_query/test_graph_embeddings_pinecone_query.py
Normal file
|
|
@ -0,0 +1,507 @@
|
|||
"""
|
||||
Tests for Pinecone graph embeddings query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.graph_embeddings.pinecone.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
|
||||
|
||||
class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||
"""Test cases for Pinecone graph embeddings query processor"""
|
||||
|
||||
@pytest.fixture
def mock_query_message(self):
    """Provide a MagicMock query carrying two 3-element vectors and a limit of 5."""
    return MagicMock(
        vectors=[
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
        ],
        limit=5,
        user='test_user',
        collection='test_collection',
    )
|
||||
|
||||
@pytest.fixture
def processor(self):
    """Build a Processor with the Pinecone client class swapped for a MagicMock."""
    target = 'trustgraph.query.graph_embeddings.pinecone.service.Pinecone'
    with patch(target) as pinecone_cls:
        pinecone_cls.return_value = MagicMock()
        # Construct inside the patch so the constructor sees the mocked class.
        return Processor(
            taskgroup=MagicMock(),
            id='test-pinecone-ge-query',
            api_key='test-api-key',
        )
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone')
@patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'env-api-key')
def test_processor_initialization_with_defaults(self, mock_pinecone_class):
    """With no explicit key the processor uses the module's default_api_key."""
    client = MagicMock()
    mock_pinecone_class.return_value = client

    processor = Processor(taskgroup=MagicMock())

    # The patched default key is what reaches the Pinecone constructor.
    mock_pinecone_class.assert_called_once_with(api_key='env-api-key')
    assert processor.pinecone == client
    assert processor.api_key == 'env-api-key'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.Pinecone')
def test_processor_initialization_with_custom_params(self, mock_pinecone_class):
    """An explicit api_key is passed to Pinecone and stored on the processor."""
    mock_pinecone_class.return_value = MagicMock()

    processor = Processor(
        taskgroup=MagicMock(),
        api_key='custom-api-key',
    )

    mock_pinecone_class.assert_called_once_with(api_key='custom-api-key')
    assert processor.api_key == 'custom-api-key'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.PineconeGRPC')
def test_processor_initialization_with_url(self, mock_pinecone_grpc_class):
    """Supplying a url makes the processor build a PineconeGRPC client for that host."""
    grpc_client = MagicMock()
    mock_pinecone_grpc_class.return_value = grpc_client
    custom_url = 'https://custom-host.pinecone.io'

    processor = Processor(
        taskgroup=MagicMock(),
        api_key='test-api-key',
        url=custom_url,
    )

    # The url is forwarded as the `host` keyword of the GRPC client.
    mock_pinecone_grpc_class.assert_called_once_with(
        api_key='test-api-key',
        host='https://custom-host.pinecone.io',
    )
    assert processor.pinecone == grpc_client
    assert processor.url == 'https://custom-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.default_api_key', 'not-specified')
def test_processor_initialization_missing_api_key(self):
    """Construction raises RuntimeError when the default key is 'not-specified'."""
    taskgroup = MagicMock()

    # No api_key argument and no usable default: the constructor must refuse.
    with pytest.raises(RuntimeError, match="Pinecone API key must be specified"):
        Processor(taskgroup=taskgroup)
|
||||
|
||||
def test_create_value_uri(self, processor):
    """create_value wraps an http:// entity in a Value flagged as a URI."""
    uri_entity = "http://example.org/entity"

    value = processor.create_value(uri_entity)

    assert isinstance(value, Value)
    assert value.value == uri_entity
    # Identity check against the bool singleton: `== True` would also accept 1.
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_https_uri(self, processor):
    """create_value wraps an https:// entity in a Value flagged as a URI."""
    uri_entity = "https://example.org/entity"

    value = processor.create_value(uri_entity)

    assert isinstance(value, Value)
    assert value.value == uri_entity
    # Identity check against the bool singleton: `== True` would also accept 1.
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_literal(self, processor):
    """create_value wraps a plain string in a Value that is not flagged as a URI."""
    literal_entity = "literal_entity"

    value = processor.create_value(literal_entity)

    assert isinstance(value, Value)
    assert value.value == literal_entity
    # Identity check against the bool singleton: `== False` would also accept 0.
    assert value.is_uri is False
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_single_vector(self, processor):
    """One query vector produces one index query and one Value per match."""
    message = MagicMock()
    message.vectors = [[0.1, 0.2, 0.3]]
    message.limit = 3
    message.user = 'test_user'
    message.collection = 'test_collection'

    # Mock index and query results: URI, literal, URI.
    mock_index = MagicMock()
    processor.pinecone.Index.return_value = mock_index

    mock_results = MagicMock()
    mock_results.matches = [
        MagicMock(metadata={'entity': 'http://example.org/entity1'}),
        MagicMock(metadata={'entity': 'entity2'}),
        MagicMock(metadata={'entity': 'http://example.org/entity3'}),
    ]
    mock_index.query.return_value = mock_results

    entities = await processor.query_graph_embeddings(message)

    # Verify index was accessed correctly.
    # NOTE(review): the trailing "-3" presumably encodes the vector
    # dimension -- confirm against the service implementation.
    expected_index_name = "t-test_user-test_collection-3"
    processor.pinecone.Index.assert_called_once_with(expected_index_name)

    # Verify query parameters
    mock_index.query.assert_called_once_with(
        vector=[0.1, 0.2, 0.3],
        top_k=6,  # 2 * limit
        include_values=False,
        include_metadata=True,
    )

    # Verify results; boolean flags are compared by identity (PEP 8) rather
    # than with `==`, which would also accept 1 / 0.
    assert len(entities) == 3
    assert entities[0].value == 'http://example.org/entity1'
    assert entities[0].is_uri is True
    assert entities[1].value == 'entity2'
    assert entities[1].is_uri is False
    assert entities[2].value == 'http://example.org/entity3'
    assert entities[2].is_uri is True
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
    """Each vector triggers a query and entities seen twice are reported once."""
    index = MagicMock()
    processor.pinecone.Index.return_value = index

    def response_with(*names):
        # Build a fake query response whose matches carry the given entities.
        response = MagicMock()
        response.matches = [MagicMock(metadata={'entity': name}) for name in names]
        return response

    index.query.side_effect = [
        response_with('entity1', 'entity2'),
        response_with('entity2', 'entity3'),  # 'entity2' is a duplicate
    ]

    entities = await processor.query_graph_embeddings(mock_query_message)

    # One query per vector in the message.
    assert index.query.call_count == 2

    # Four raw matches collapse to three distinct entities.
    entity_values = [entity.value for entity in entities]
    assert len(entity_values) == 3
    for name in ('entity1', 'entity2', 'entity3'):
        assert name in entity_values
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_limit_handling(self, processor):
    """No more than `limit` entities are returned even when the index offers more."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=2,
        user='test_user',
        collection='test_collection',
    )

    # The index answers with ten distinct matches.
    index = MagicMock()
    processor.pinecone.Index.return_value = index
    response = MagicMock()
    response.matches = [
        MagicMock(metadata={'entity': f'entity{i}'}) for i in range(10)
    ]
    index.query.return_value = response

    entities = await processor.query_graph_embeddings(request)

    # The answer is trimmed to the requested limit.
    assert len(entities) == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_zero_limit(self, processor):
    """A limit of zero issues no index query and yields an empty list."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=0,
        user='test_user',
        collection='test_collection',
    )

    index = MagicMock()
    processor.pinecone.Index.return_value = index

    entities = await processor.query_graph_embeddings(request)

    # Nothing was asked of the index and nothing comes back.
    index.query.assert_not_called()
    assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_negative_limit(self, processor):
    """A request with a negative ``limit`` must return ``[]`` without querying."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=-1,
        user='test_user',
        collection='test_collection',
    )

    index = MagicMock()
    processor.pinecone.Index.return_value = index

    entities = await processor.query_graph_embeddings(request)

    # The index is never queried and nothing is returned
    index.query.assert_not_called()
    assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
    """Check that vectors of different lengths are served by different indexes.

    The mocked ``Index`` factory hands out one index for names ending in
    ``-2`` and another for names ending in ``-4``. Each index must be
    queried exactly once and both result sets must appear in the output.
    """
    request = MagicMock(
        vectors=[
            [0.1, 0.2],            # 2D vector
            [0.3, 0.4, 0.5, 0.6],  # 4D vector
        ],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    # One mocked index per dimension, each with its own single match
    index_2d = MagicMock()
    index_2d.query.return_value = MagicMock(
        matches=[MagicMock(metadata={'entity': 'entity_2d'})]
    )
    index_4d = MagicMock()
    index_4d.query.return_value = MagicMock(
        matches=[MagicMock(metadata={'entity': 'entity_4d'})]
    )
    by_suffix = {"-2": index_2d, "-4": index_4d}

    def select_index(name):
        # Any name without a known suffix yields None
        for suffix, index in by_suffix.items():
            if name.endswith(suffix):
                return index
        return None

    processor.pinecone.Index.side_effect = select_index

    entities = await processor.query_graph_embeddings(request)

    # A separate index was requested and queried for each dimension
    assert processor.pinecone.Index.call_count == 2
    index_2d.query.assert_called_once()
    index_4d.query.assert_called_once()

    # Results from both dimensions are present
    found = [entity.value for entity in entities]
    assert 'entity_2d' in found
    assert 'entity_4d' in found
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
    """A request carrying no vectors must return ``[]`` and touch no index."""
    request = MagicMock(
        vectors=[],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    index = MagicMock()
    processor.pinecone.Index.return_value = index

    entities = await processor.query_graph_embeddings(request)

    # No index is looked up, no query is issued, nothing is returned
    processor.pinecone.Index.assert_not_called()
    index.query.assert_not_called()
    assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_no_results(self, processor):
    """An index query that yields no matches must produce an empty list."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    # Index whose query comes back with an empty match list
    index = MagicMock()
    processor.pinecone.Index.return_value = index
    index.query.return_value = MagicMock(matches=[])

    entities = await processor.query_graph_embeddings(request)

    assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
    """Check de-duplication of entities across several vector queries.

    Two query vectors produce overlapping match lists; with ``limit == 3``
    the result must contain exactly three entities, all distinct.
    """

    def matches_for(*names):
        # One mocked match per entity name
        return [MagicMock(metadata={'entity': name}) for name in names]

    request = MagicMock(
        vectors=[
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
        ],
        limit=3,
        user='test_user',
        collection='test_collection',
    )

    index = MagicMock()
    processor.pinecone.Index.return_value = index

    # The second response repeats entity2 and entity3 from the first
    first_response = MagicMock(
        matches=matches_for('entity1', 'entity2', 'entity3', 'entity4')
    )
    second_response = MagicMock(
        matches=matches_for('entity2', 'entity3', 'entity5')
    )
    index.query.side_effect = [first_response, second_response]

    entities = await processor.query_graph_embeddings(request)

    # Exactly ``limit`` entities, with no repeats
    assert len(entities) == 3
    found = [entity.value for entity in entities]
    assert len(set(found)) == 3  # All unique
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
    """Check that no further queries are issued once ``limit`` is reached.

    Three vectors are supplied, but the response to the first already
    holds three entities against a limit of two, so the index must be
    queried only once.
    """
    request = MagicMock(
        vectors=[
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9],
        ],
        limit=2,
        user='test_user',
        collection='test_collection',
    )

    index = MagicMock()
    processor.pinecone.Index.return_value = index

    # Every query would return enough matches to satisfy the limit
    index.query.return_value = MagicMock(
        matches=[
            MagicMock(metadata={'entity': name})
            for name in ('entity1', 'entity2', 'entity3')
        ]
    )

    entities = await processor.query_graph_embeddings(request)

    # A single query was enough, and the result is trimmed to the limit
    index.query.assert_called_once()
    assert len(entities) == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_graph_embeddings_exception_handling(self, processor):
    """An exception raised by the index query must propagate to the caller."""
    request = MagicMock(
        vectors=[[0.1, 0.2, 0.3]],
        limit=5,
        user='test_user',
        collection='test_collection',
    )

    # Index whose query always fails
    failing_index = MagicMock()
    processor.pinecone.Index.return_value = failing_index
    failing_index.query.side_effect = Exception("Query failed")

    with pytest.raises(Exception, match="Query failed"):
        await processor.query_graph_embeddings(request)
|
||||
|
||||
def test_add_args_method(self):
    """Check that ``add_args`` registers its options and calls the parent.

    The parent ``add_args`` is patched out; afterwards, parsing an empty
    command line must expose ``api_key`` and ``url`` with their defaults.
    """
    from argparse import ArgumentParser
    from unittest.mock import patch

    parent_add_args = (
        'trustgraph.query.graph_embeddings.pinecone.service.'
        'GraphEmbeddingsQueryService.add_args'
    )
    parser = ArgumentParser()

    # Replace the parent class add_args so only our own arguments are added
    with patch(parent_add_args) as mock_parent_add_args:
        Processor.add_args(parser)

    # The parent hook must have been invoked
    mock_parent_add_args.assert_called_once()

    # Our own arguments exist and carry their defaults
    parsed = parser.parse_args([])

    assert hasattr(parsed, 'api_key')
    assert parsed.api_key == 'not-specified'  # Default value when no env var
    assert hasattr(parsed, 'url')
    assert parsed.url is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
    """Check that ``--api-key`` and ``--url`` values are parsed as given."""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parent_add_args = (
        'trustgraph.query.graph_embeddings.pinecone.service.'
        'GraphEmbeddingsQueryService.add_args'
    )
    parser = ArgumentParser()

    with patch(parent_add_args):
        Processor.add_args(parser)

    # Parse a command line that uses the long option names
    argv = [
        '--api-key', 'custom-api-key',
        '--url', 'https://custom-host.pinecone.io',
    ]
    parsed = parser.parse_args(argv)

    assert parsed.api_key == 'custom-api-key'
    assert parsed.url == 'https://custom-host.pinecone.io'
|
||||
|
||||
def test_add_args_short_form(self):
    """Check that the short options ``-a`` and ``-u`` are parsed as given."""
    from argparse import ArgumentParser
    from unittest.mock import patch

    parent_add_args = (
        'trustgraph.query.graph_embeddings.pinecone.service.'
        'GraphEmbeddingsQueryService.add_args'
    )
    parser = ArgumentParser()

    with patch(parent_add_args):
        Processor.add_args(parser)

    # Parse a command line that uses the short option names
    argv = [
        '-a', 'short-api-key',
        '-u', 'https://short-host.pinecone.io',
    ]
    parsed = parser.parse_args(argv)

    assert parsed.api_key == 'short-api-key'
    assert parsed.url == 'https://short-host.pinecone.io'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.pinecone.service.Processor.launch')
def test_run_function(self, mock_launch):
    """Check that ``run`` calls ``Processor.launch`` with the ident and description."""
    from trustgraph.query.graph_embeddings.pinecone.service import run, default_ident

    # Description text that run() is expected to pass to launch
    expected_description = (
        "\nGraph embeddings query service. Input is vector, output is list of\n"
        "entities. Pinecone implementation.\n"
    )

    run()

    mock_launch.assert_called_once_with(default_ident, expected_description)
|
||||
537
tests/unit/test_query/test_graph_embeddings_qdrant_query.py
Normal file
537
tests/unit/test_query/test_graph_embeddings_qdrant_query.py
Normal file
|
|
@ -0,0 +1,537 @@
|
|||
"""
|
||||
Unit tests for trustgraph.query.graph_embeddings.qdrant.service
|
||||
Testing graph embeddings query functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
    """Unit tests for the Qdrant graph embeddings query ``Processor``.

    Every test patches ``QdrantClient`` and the base service ``__init__``,
    so the processor is always exercised against mocks.
    """

    # Patch targets shared by all tests in this class
    _QDRANT_CLIENT = 'trustgraph.query.graph_embeddings.qdrant.service.QdrantClient'
    _BASE_INIT = 'trustgraph.base.GraphEmbeddingsQueryService.__init__'

    @staticmethod
    def _install_mocks(mock_base_init, mock_qdrant_client):
        """Neutralise the base ``__init__`` and return the mocked client instance."""
        mock_base_init.return_value = None
        qdrant = MagicMock()
        mock_qdrant_client.return_value = qdrant
        return qdrant

    @staticmethod
    def _new_processor(**overrides):
        """Build a ``Processor`` from the minimal config, updated with ``overrides``."""
        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor',
        }
        config.update(overrides)
        return Processor(**config)

    @staticmethod
    def _point(entity):
        """Return a mocked Qdrant point whose payload names ``entity``."""
        return MagicMock(payload={'entity': entity})

    @staticmethod
    def _request(vectors, limit, user, collection):
        """Return a mocked query message with the given fields."""
        return MagicMock(
            vectors=vectors, limit=limit, user=user, collection=collection
        )

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
        """Test basic Qdrant processor initialization"""
        # Arrange
        qdrant = self._install_mocks(mock_base_init, mock_qdrant_client)

        # Act
        processor = self._new_processor(
            store_uri='http://localhost:6333',
            api_key='test-api-key',
            id='test-graph-query-processor',
        )

        # Assert: base class initialization ran
        mock_base_init.assert_called_once()

        # Assert: the client was built from the supplied URI and key
        mock_qdrant_client.assert_called_once_with(
            url='http://localhost:6333', api_key='test-api-key'
        )

        # Assert: the processor holds on to the client instance
        assert hasattr(processor, 'qdrant')
        assert processor.qdrant == qdrant

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
        """Test processor initialization with default values"""
        # Arrange
        self._install_mocks(mock_base_init, mock_qdrant_client)

        # Act: no store_uri or api_key provided, so defaults apply
        self._new_processor(id='test-graph-query-processor')

        # Assert: default URI and a None API key reach the client
        mock_qdrant_client.assert_called_once_with(
            url='http://localhost:6333', api_key=None
        )

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_create_value_http_uri(self, mock_base_init, mock_qdrant_client):
        """Test create_value with HTTP URI"""
        # Arrange
        self._install_mocks(mock_base_init, mock_qdrant_client)
        processor = self._new_processor()

        # Act
        value = processor.create_value('http://example.com/entity')

        # Assert
        assert hasattr(value, 'value')
        assert value.value == 'http://example.com/entity'
        assert hasattr(value, 'is_uri')
        assert value.is_uri == True

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_create_value_https_uri(self, mock_base_init, mock_qdrant_client):
        """Test create_value with HTTPS URI"""
        # Arrange
        self._install_mocks(mock_base_init, mock_qdrant_client)
        processor = self._new_processor()

        # Act
        value = processor.create_value('https://secure.example.com/entity')

        # Assert
        assert hasattr(value, 'value')
        assert value.value == 'https://secure.example.com/entity'
        assert hasattr(value, 'is_uri')
        assert value.is_uri == True

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_create_value_regular_string(self, mock_base_init, mock_qdrant_client):
        """Test create_value with regular string (non-URI)"""
        # Arrange
        self._install_mocks(mock_base_init, mock_qdrant_client)
        processor = self._new_processor()

        # Act
        value = processor.create_value('regular entity name')

        # Assert
        assert hasattr(value, 'value')
        assert value.value == 'regular entity name'
        assert hasattr(value, 'is_uri')
        assert value.is_uri == False

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_query_graph_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
        """Test querying graph embeddings with single vector"""
        # Arrange
        qdrant = self._install_mocks(mock_base_init, mock_qdrant_client)
        qdrant.query_points.return_value = MagicMock(
            points=[self._point('entity1'), self._point('entity2')]
        )
        processor = self._new_processor()
        request = self._request(
            [[0.1, 0.2, 0.3]], 5, 'test_user', 'test_collection'
        )

        # Act
        result = await processor.query_graph_embeddings(request)

        # Assert: the query went to the expected collection
        qdrant.query_points.assert_called_once_with(
            collection_name='t_test_user_test_collection_3',
            query=[0.1, 0.2, 0.3],
            limit=10,  # limit * 2 for deduplication
            with_payload=True,
        )

        # Assert: both entities come back
        assert len(result) == 2
        assert all(hasattr(entity, 'value') for entity in result)
        found = [entity.value for entity in result]
        assert 'entity1' in found
        assert 'entity2' in found

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_query_graph_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
        """Test querying graph embeddings with multiple vectors"""
        # Arrange
        qdrant = self._install_mocks(mock_base_init, mock_qdrant_client)

        # The same point for entity2 appears in both responses
        point1, point2, point3 = (
            self._point(name) for name in ('entity1', 'entity2', 'entity3')
        )
        qdrant.query_points.side_effect = [
            MagicMock(points=[point1, point2]),
            MagicMock(points=[point2, point3]),
        ]
        processor = self._new_processor()
        request = self._request(
            [[0.1, 0.2], [0.3, 0.4]], 3, 'multi_user', 'multi_collection'
        )

        # Act
        result = await processor.query_graph_embeddings(request)

        # Assert: one query per vector
        assert qdrant.query_points.call_count == 2

        # Assert: both queries hit the same collection, one vector each
        expected_collection = 't_multi_user_multi_collection_2'
        first_call, second_call = qdrant.query_points.call_args_list
        assert first_call.kwargs['collection_name'] == expected_collection
        assert second_call.kwargs['collection_name'] == expected_collection
        assert first_call.kwargs['query'] == [0.1, 0.2]
        assert second_call.kwargs['query'] == [0.3, 0.4]

        # Assert: entity2 was in both responses but is reported once
        found = [entity.value for entity in result]
        assert len(set(found)) == len(found)  # All unique
        assert 'entity1' in found
        assert 'entity2' in found
        assert 'entity3' in found

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_query_graph_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
        """Test querying graph embeddings respects limit parameter"""
        # Arrange: the response holds more points than the limit
        qdrant = self._install_mocks(mock_base_init, mock_qdrant_client)
        qdrant.query_points.return_value = MagicMock(
            points=[self._point(f'entity{i}') for i in range(10)]
        )
        processor = self._new_processor()
        request = self._request(
            [[0.1, 0.2, 0.3]], 3, 'limit_user', 'limit_collection'
        )

        # Act
        result = await processor.query_graph_embeddings(request)

        # Assert: the query asked for limit * 2
        qdrant.query_points.assert_called_once()
        assert qdrant.query_points.call_args.kwargs['limit'] == 6  # 3 * 2

        # Assert: the result is trimmed to the requested limit
        assert len(result) == 3

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_query_graph_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
        """Test querying graph embeddings with empty results"""
        # Arrange
        qdrant = self._install_mocks(mock_base_init, mock_qdrant_client)
        qdrant.query_points.return_value = MagicMock(points=[])
        processor = self._new_processor()
        request = self._request(
            [[0.1, 0.2]], 5, 'empty_user', 'empty_collection'
        )

        # Act
        result = await processor.query_graph_embeddings(request)

        # Assert
        assert result == []

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_query_graph_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
        """Test querying graph embeddings with different vector dimensions"""
        # Arrange
        qdrant = self._install_mocks(mock_base_init, mock_qdrant_client)
        qdrant.query_points.side_effect = [
            MagicMock(points=[self._point('entity2d')]),
            MagicMock(points=[self._point('entity3d')]),
        ]
        processor = self._new_processor()

        # One 2D and one 3D vector
        request = self._request(
            [[0.1, 0.2], [0.3, 0.4, 0.5]], 5, 'dim_user', 'dim_collection'
        )

        # Act
        result = await processor.query_graph_embeddings(request)

        # Assert: two queries, each against its own collection
        assert qdrant.query_points.call_count == 2
        first_call, second_call = qdrant.query_points.call_args_list

        # The 2D vector goes to the 2D collection
        assert first_call.kwargs['collection_name'] == 't_dim_user_dim_collection_2'
        assert first_call.kwargs['query'] == [0.1, 0.2]

        # The 3D vector goes to the 3D collection
        assert second_call.kwargs['collection_name'] == 't_dim_user_dim_collection_3'
        assert second_call.kwargs['query'] == [0.3, 0.4, 0.5]

        # Assert: results from both queries are present
        found = [entity.value for entity in result]
        assert 'entity2d' in found
        assert 'entity3d' in found

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_query_graph_embeddings_uri_detection(self, mock_base_init, mock_qdrant_client):
        """Test querying graph embeddings with URI detection"""
        # Arrange: a response mixing URIs and a plain string
        qdrant = self._install_mocks(mock_base_init, mock_qdrant_client)
        qdrant.query_points.return_value = MagicMock(
            points=[
                self._point('http://example.com/entity1'),
                self._point('https://secure.example.com/entity2'),
                self._point('regular entity'),
            ]
        )
        processor = self._new_processor()
        request = self._request(
            [[0.1, 0.2]], 5, 'uri_user', 'uri_collection'
        )

        # Act
        result = await processor.query_graph_embeddings(request)

        # Assert
        assert len(result) == 3

        # The two URI entities are flagged as such
        uri_entities = [
            entity for entity in result
            if hasattr(entity, 'is_uri') and entity.is_uri
        ]
        assert len(uri_entities) == 2
        uri_values = [entity.value for entity in uri_entities]
        assert 'http://example.com/entity1' in uri_values
        assert 'https://secure.example.com/entity2' in uri_values

        # The plain string is not flagged as a URI
        regular_entities = [
            entity for entity in result
            if hasattr(entity, 'is_uri') and not entity.is_uri
        ]
        assert len(regular_entities) == 1
        assert regular_entities[0].value == 'regular entity'

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_query_graph_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
        """Test querying graph embeddings handles Qdrant errors"""
        # Arrange: every query raises
        qdrant = self._install_mocks(mock_base_init, mock_qdrant_client)
        qdrant.query_points.side_effect = Exception("Qdrant connection failed")
        processor = self._new_processor()
        request = self._request(
            [[0.1, 0.2]], 5, 'error_user', 'error_collection'
        )

        # Act & Assert
        with pytest.raises(Exception, match="Qdrant connection failed"):
            await processor.query_graph_embeddings(request)

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_query_graph_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
        """Test querying graph embeddings with zero limit"""
        # Arrange: even with zero limit, Qdrant might return results
        qdrant = self._install_mocks(mock_base_init, mock_qdrant_client)
        qdrant.query_points.return_value = MagicMock(
            points=[self._point('entity1')]
        )
        processor = self._new_processor()
        request = self._request(
            [[0.1, 0.2]], 0, 'zero_user', 'zero_collection'
        )

        # Act
        result = await processor.query_graph_embeddings(request)

        # Assert: the query is still issued (with limit 0)
        qdrant.query_points.assert_called_once()
        assert qdrant.query_points.call_args.kwargs['limit'] == 0  # 0 * 2 = 0

        # With zero limit, the logic still adds one entity before checking the limit
        # So it returns one result (current behavior, not ideal but actual)
        assert len(result) == 1
        assert result[0].value == 'entity1'

    @patch(_QDRANT_CLIENT)
    @patch(_BASE_INIT)
    async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
        """Test that add_args() calls parent add_args method"""
        # Arrange
        self._install_mocks(mock_base_init, mock_qdrant_client)
        parser = MagicMock()

        # Act
        with patch('trustgraph.base.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
            Processor.add_args(parser)

        # Assert: the parent hook received the same parser
        mock_parent_add_args.assert_called_once_with(parser)

        # Assert: processor-specific arguments were added
        assert parser.add_argument.call_count >= 2  # At least store-uri and api-key
|
||||
|
||||
|
||||
# Allow this test module to be run directly as a script.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
539
tests/unit/test_query/test_triples_cassandra_query.py
Normal file
539
tests/unit/test_query/test_triples_cassandra_query.py
Normal file
|
|
@ -0,0 +1,539 @@
|
|||
"""
|
||||
Tests for Cassandra triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.cassandra.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
|
||||
|
||||
class TestCassandraQueryProcessor:
|
||||
"""Test cases for Cassandra query processor"""
|
||||
|
||||
@pytest.fixture
def processor(self):
    """Provide a ``Processor`` built with a mocked task group for each test."""
    settings = {
        'taskgroup': MagicMock(),
        'id': 'test-cassandra-query',
        'graph_host': 'localhost',
    }
    return Processor(**settings)
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
    """create_value marks an http:// string as a URI."""
    text = "http://example.com/resource"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
    """create_value marks an https:// string as a URI."""
    text = "https://example.com/resource"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_with_literal(self, processor):
    """create_value keeps plain text as a non-URI value."""
    text = "just a literal string"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
    """create_value accepts the empty string and treats it as a non-URI value."""
    text = ""

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
    """A bare scheme name ("http") with nothing after it is not a URI."""
    text = "http"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
    """An ftp:// string is expected to come back as a non-URI value."""
    text = "ftp://example.com/file"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_spo_query(self, mock_trustgraph):
    """A fully specified (s, p, o) request calls get_spo and echoes the triple back."""
    from trustgraph.schema import TriplesQueryRequest, Value

    # Every TrustGraph(...) construction hands back this stand-in
    tg = MagicMock()
    tg.get_spo.return_value = None  # SPO query returns None if found
    mock_trustgraph.return_value = tg

    processor = Processor(
        taskgroup=MagicMock(),
        id='test-cassandra-query',
        graph_host='localhost',
    )

    # Request with subject, predicate and object all supplied
    subject = Value(value='test_subject', is_uri=False)
    predicate = Value(value='test_predicate', is_uri=False)
    obj = Value(value='test_object', is_uri=False)
    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=subject,
        p=predicate,
        o=obj,
        limit=100,
    )

    triples = await processor.query_triples(request)

    # The backend is opened for the request's user/collection
    mock_trustgraph.assert_called_once_with(
        hosts=['localhost'],
        keyspace='test_user',
        table='test_collection',
    )

    # The lookup receives the raw string values and the limit
    tg.get_spo.assert_called_once_with(
        'test_subject', 'test_predicate', 'test_object', limit=100
    )

    # Exactly one triple, identical to the one asked for
    assert len(triples) == 1
    triple = triples[0]
    assert triple.s.value == 'test_subject'
    assert triple.p.value == 'test_predicate'
    assert triple.o.value == 'test_object'
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
    """With only a task group given, the processor falls back to its defaults."""
    processor = Processor(taskgroup=MagicMock())

    # Host is stored as a single-element list
    assert processor.graph_host == ['localhost']

    # No credentials and no table selected yet
    for attribute in ('username', 'password', 'table'):
        assert getattr(processor, attribute) is None
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
    """Explicit host and credentials are stored on the processor."""
    init_kwargs = {
        'taskgroup': MagicMock(),
        'graph_host': 'cassandra.example.com',
        'graph_username': 'queryuser',
        'graph_password': 'querypass',
    }

    processor = Processor(**init_kwargs)

    assert processor.graph_host == ['cassandra.example.com']
    assert processor.username == 'queryuser'
    assert processor.password == 'querypass'
    # No query has run, so no table is selected
    assert processor.table is None
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_sp_pattern(self, mock_trustgraph):
    """With s and p given but no o, get_sp supplies the object of each triple."""
    from trustgraph.schema import TriplesQueryRequest, Value

    # Backend stand-in returning a single row that carries only an object
    row = MagicMock(o='result_object')
    tg = MagicMock()
    tg.get_sp.return_value = [row]
    mock_trustgraph.return_value = tg

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value='test_subject', is_uri=False),
        p=Value(value='test_predicate', is_uri=False),
        o=None,
        limit=50,
    )

    triples = await processor.query_triples(request)

    tg.get_sp.assert_called_once_with('test_subject', 'test_predicate', limit=50)
    assert len(triples) == 1
    triple = triples[0]
    assert triple.s.value == 'test_subject'
    assert triple.p.value == 'test_predicate'
    assert triple.o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_s_pattern(self, mock_trustgraph):
    """With only s given, get_s supplies the predicate and object of each triple."""
    from trustgraph.schema import TriplesQueryRequest, Value

    # Backend stand-in returning one row with predicate and object
    row = MagicMock(p='result_predicate', o='result_object')
    tg = MagicMock()
    tg.get_s.return_value = [row]
    mock_trustgraph.return_value = tg

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value='test_subject', is_uri=False),
        p=None,
        o=None,
        limit=25,
    )

    triples = await processor.query_triples(request)

    tg.get_s.assert_called_once_with('test_subject', limit=25)
    assert len(triples) == 1
    triple = triples[0]
    assert triple.s.value == 'test_subject'
    assert triple.p.value == 'result_predicate'
    assert triple.o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_p_pattern(self, mock_trustgraph):
    """With only p given, get_p supplies the subject and object of each triple."""
    from trustgraph.schema import TriplesQueryRequest, Value

    # Backend stand-in returning one row with subject and object
    row = MagicMock(s='result_subject', o='result_object')
    tg = MagicMock()
    tg.get_p.return_value = [row]
    mock_trustgraph.return_value = tg

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=Value(value='test_predicate', is_uri=False),
        o=None,
        limit=10,
    )

    triples = await processor.query_triples(request)

    tg.get_p.assert_called_once_with('test_predicate', limit=10)
    assert len(triples) == 1
    triple = triples[0]
    assert triple.s.value == 'result_subject'
    assert triple.p.value == 'test_predicate'
    assert triple.o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_o_pattern(self, mock_trustgraph):
    """With only o given, get_o supplies the subject and predicate of each triple."""
    from trustgraph.schema import TriplesQueryRequest, Value

    # Backend stand-in returning one row with subject and predicate
    row = MagicMock(s='result_subject', p='result_predicate')
    tg = MagicMock()
    tg.get_o.return_value = [row]
    mock_trustgraph.return_value = tg

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=None,
        o=Value(value='test_object', is_uri=False),
        limit=75,
    )

    triples = await processor.query_triples(request)

    tg.get_o.assert_called_once_with('test_object', limit=75)
    assert len(triples) == 1
    triple = triples[0]
    assert triple.s.value == 'result_subject'
    assert triple.p.value == 'result_predicate'
    assert triple.o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
    """With s, p and o all None, get_all supplies every field of each triple."""
    from trustgraph.schema import TriplesQueryRequest

    # Backend stand-in returning one fully populated row
    row = MagicMock(s='all_subject', p='all_predicate', o='all_object')
    tg = MagicMock()
    tg.get_all.return_value = [row]
    mock_trustgraph.return_value = tg

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=None,
        o=None,
        limit=1000,
    )

    triples = await processor.query_triples(request)

    tg.get_all.assert_called_once_with(limit=1000)
    assert len(triples) == 1
    triple = triples[0]
    assert triple.s.value == 'all_subject'
    assert triple.p.value == 'all_predicate'
    assert triple.o.value == 'all_object'
|
||||
|
||||
def test_add_args_method(self):
    """add_args calls the parent add_args and registers the graph options with defaults."""
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parent_target = 'trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'

    # Replace the parent implementation so only the call is recorded
    with patch(parent_target) as parent_add_args:
        Processor.add_args(parser)

    # The parent must have been handed our parser exactly once
    parent_add_args.assert_called_once_with(parser)

    # Parsing an empty command line exposes the defaults
    parsed = parser.parse_args([])

    assert hasattr(parsed, 'graph_host')
    assert hasattr(parsed, 'graph_username')
    assert hasattr(parsed, 'graph_password')

    assert parsed.graph_host == 'localhost'
    assert parsed.graph_username is None
    assert parsed.graph_password is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
    """Long-form graph options given on the command line are parsed as supplied."""
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parent_target = 'trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'

    with patch(parent_target):
        Processor.add_args(parser)

    # Command line overriding every graph option
    argv = [
        '--graph-host', 'query.cassandra.com',
        '--graph-username', 'queryuser',
        '--graph-password', 'querypass',
    ]
    parsed = parser.parse_args(argv)

    assert parsed.graph_host == 'query.cassandra.com'
    assert parsed.graph_username == 'queryuser'
    assert parsed.graph_password == 'querypass'
|
||||
|
||||
def test_add_args_short_form(self):
    """The short option -g sets graph_host."""
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parent_target = 'trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'

    with patch(parent_target):
        Processor.add_args(parser)

    # Parse using the short flag only
    argv = ['-g', 'short.query.com']
    parsed = parser.parse_args(argv)

    assert parsed.graph_host == 'short.query.com'
|
||||
|
||||
@patch('trustgraph.query.triples.cassandra.service.Processor.launch')
def test_run_function(self, mock_launch):
    """run() hands the default identifier and the service description to Processor.launch."""
    from trustgraph.query.triples.cassandra.service import run, default_ident

    # Description text that run() is expected to pass through verbatim
    expected_doc = '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n'

    run()

    mock_launch.assert_called_once_with(default_ident, expected_doc)
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_with_authentication(self, mock_trustgraph):
    """Configured credentials are forwarded when the TrustGraph backend is created."""
    from trustgraph.schema import TriplesQueryRequest, Value

    tg = MagicMock()
    tg.get_spo.return_value = None
    mock_trustgraph.return_value = tg

    processor = Processor(
        taskgroup=MagicMock(),
        graph_username='authuser',
        graph_password='authpass',
    )

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value='test_subject', is_uri=False),
        p=Value(value='test_predicate', is_uri=False),
        o=Value(value='test_object', is_uri=False),
        limit=100,
    )

    await processor.query_triples(request)

    # The constructor call must carry username and password
    expected_kwargs = {
        'hosts': ['localhost'],
        'keyspace': 'test_user',
        'table': 'test_collection',
        'username': 'authuser',
        'password': 'authpass',
    }
    mock_trustgraph.assert_called_once_with(**expected_kwargs)
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_table_reuse(self, mock_trustgraph):
    """Two queries against the same user/collection construct TrustGraph only once."""
    from trustgraph.schema import TriplesQueryRequest, Value

    tg = MagicMock()
    tg.get_spo.return_value = None
    mock_trustgraph.return_value = tg

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value='test_subject', is_uri=False),
        p=Value(value='test_predicate', is_uri=False),
        o=Value(value='test_object', is_uri=False),
        limit=100,
    )

    # The first query has to create the backend
    await processor.query_triples(request)
    assert mock_trustgraph.call_count == 1

    # Repeating the same request must not create another one
    await processor.query_triples(request)
    assert mock_trustgraph.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_table_switching(self, mock_trustgraph):
    """A request for a different user/collection constructs a second TrustGraph."""
    from trustgraph.schema import TriplesQueryRequest, Value

    # One distinct backend stand-in per expected construction
    first_tg = MagicMock()
    second_tg = MagicMock()
    mock_trustgraph.side_effect = [first_tg, second_tg]

    processor = Processor(taskgroup=MagicMock())

    # Query against the first user/collection pair
    first_request = TriplesQueryRequest(
        user='user1',
        collection='collection1',
        s=Value(value='test_subject', is_uri=False),
        p=None,
        o=None,
        limit=100,
    )
    await processor.query_triples(first_request)
    assert processor.table == ('user1', 'collection1')

    # Query against a different user/collection pair
    second_request = TriplesQueryRequest(
        user='user2',
        collection='collection2',
        s=Value(value='test_subject', is_uri=False),
        p=None,
        o=None,
        limit=100,
    )
    await processor.query_triples(second_request)
    assert processor.table == ('user2', 'collection2')

    # Each pair required its own construction
    assert mock_trustgraph.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_exception_handling(self, mock_trustgraph):
    """An exception raised by the backend lookup propagates out of query_triples."""
    from trustgraph.schema import TriplesQueryRequest, Value

    # Backend stand-in whose SPO lookup always fails
    tg = MagicMock()
    tg.get_spo.side_effect = Exception("Query failed")
    mock_trustgraph.return_value = tg

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value='test_subject', is_uri=False),
        p=Value(value='test_predicate', is_uri=False),
        o=Value(value='test_object', is_uri=False),
        limit=100,
    )

    with pytest.raises(Exception, match="Query failed"):
        await processor.query_triples(request)
|
||||
|
||||
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
async def test_query_triples_multiple_results(self, mock_trustgraph):
    """Every row returned by the backend becomes a triple, in the same order."""
    from trustgraph.schema import TriplesQueryRequest, Value

    # Two backend rows that differ only in their object
    rows = [MagicMock(o='object1'), MagicMock(o='object2')]
    tg = MagicMock()
    tg.get_sp.return_value = rows
    mock_trustgraph.return_value = tg

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value='test_subject', is_uri=False),
        p=Value(value='test_predicate', is_uri=False),
        o=None,
        limit=100,
    )

    triples = await processor.query_triples(request)

    assert len(triples) == 2
    first, second = triples[0], triples[1]
    assert first.o.value == 'object1'
    assert second.o.value == 'object2'
|
||||
556
tests/unit/test_query/test_triples_falkordb_query.py
Normal file
556
tests/unit/test_query/test_triples_falkordb_query.py
Normal file
|
|
@ -0,0 +1,556 @@
|
|||
"""
|
||||
Tests for FalkorDB triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.falkordb.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
|
||||
|
||||
class TestFalkorDBQueryProcessor:
|
||||
"""Test cases for FalkorDB query processor"""
|
||||
|
||||
@pytest.fixture
def processor(self):
    """Provide a Processor constructed while the FalkorDB client class is patched out."""
    init_kwargs = {
        'taskgroup': MagicMock(),
        'id': 'test-falkordb-query',
        'graph_url': 'falkor://localhost:6379',
    }
    # The patch is only active during construction
    with patch('trustgraph.query.triples.falkordb.service.FalkorDB'):
        instance = Processor(**init_kwargs)
        return instance
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
    """create_value marks an http:// string as a URI."""
    text = "http://example.com/resource"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
    """create_value marks an https:// string as a URI."""
    text = "https://example.com/resource"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_with_literal(self, processor):
    """create_value keeps plain text as a non-URI value."""
    text = "just a literal string"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
    """create_value accepts the empty string and treats it as a non-URI value."""
    text = ""

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
    """A bare scheme name ("http") with nothing after it is not a URI."""
    text = "http"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
    """An ftp:// string is expected to come back as a non-URI value."""
    text = "ftp://example.com/file"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
def test_processor_initialization_with_defaults(self, mock_falkordb):
    """Without overrides the processor connects to the default URL and database."""
    # Wire the client chain: FalkorDB.from_url(...) -> client -> select_graph(...)
    client = MagicMock()
    client.select_graph.return_value = MagicMock()
    mock_falkordb.from_url.return_value = client

    processor = Processor(taskgroup=MagicMock())

    assert processor.db == 'falkordb'
    mock_falkordb.from_url.assert_called_once_with('falkor://falkordb:6379')
    client.select_graph.assert_called_once_with('falkordb')
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
def test_processor_initialization_with_custom_params(self, mock_falkordb):
    """An explicit graph URL and database name are used for the connection."""
    # Wire the client chain: FalkorDB.from_url(...) -> client -> select_graph(...)
    client = MagicMock()
    client.select_graph.return_value = MagicMock()
    mock_falkordb.from_url.return_value = client

    init_kwargs = {
        'taskgroup': MagicMock(),
        'graph_url': 'falkor://custom:6379',
        'database': 'customdb',
    }
    processor = Processor(**init_kwargs)

    assert processor.db == 'customdb'
    mock_falkordb.from_url.assert_called_once_with('falkor://custom:6379')
    client.select_graph.assert_called_once_with('customdb')
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_spo_query(self, mock_falkordb):
    """Fully specified (s, p, o): two graph queries run and each hit yields the triple."""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # Every graph.query() call returns the same single-record result
    hit = MagicMock(result_set=[["record1"]])
    graph.query.return_value = hit

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=Value(value="literal object", is_uri=False),
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Two queries are issued against the graph
    assert graph.query.call_count == 2

    # One triple per query, so the requested triple appears twice
    assert len(triples) == 2
    head = triples[0]
    assert head.s.value == "http://example.com/subject"
    assert head.p.value == "http://example.com/predicate"
    assert head.o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_sp_query(self, mock_falkordb):
    """Fixed s and p: each of the two graph queries contributes one object."""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # First graph.query() call yields a literal, the second a URI
    graph.query.side_effect = [
        MagicMock(result_set=[["literal result"]]),
        MagicMock(result_set=[["http://example.com/uri_result"]]),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=None,
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Two queries are issued against the graph
    assert graph.query.call_count == 2

    # One triple per query; only the object differs
    assert len(triples) == 2
    found = [(t.s.value, t.p.value, t.o.value) for t in triples]
    assert found[0] == (
        "http://example.com/subject",
        "http://example.com/predicate",
        "literal result",
    )
    assert found[1] == (
        "http://example.com/subject",
        "http://example.com/predicate",
        "http://example.com/uri_result",
    )
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_so_query(self, mock_falkordb):
    """Fixed s and o: each of the two graph queries contributes one predicate."""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # Each graph.query() call yields a different predicate
    graph.query.side_effect = [
        MagicMock(result_set=[["http://example.com/pred1"]]),
        MagicMock(result_set=[["http://example.com/pred2"]]),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=Value(value="literal object", is_uri=False),
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Two queries are issued against the graph
    assert graph.query.call_count == 2

    # One triple per query; only the predicate differs
    assert len(triples) == 2
    found = [(t.s.value, t.p.value, t.o.value) for t in triples]
    assert found[0] == (
        "http://example.com/subject",
        "http://example.com/pred1",
        "literal object",
    )
    assert found[1] == (
        "http://example.com/subject",
        "http://example.com/pred2",
        "literal object",
    )
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_s_query(self, mock_falkordb):
    """Fixed s only: each of the two graph queries contributes a predicate/object pair."""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # Rows are [predicate, object]; the first object is a literal, the second a URI
    graph.query.side_effect = [
        MagicMock(result_set=[["http://example.com/pred1", "literal1"]]),
        MagicMock(result_set=[["http://example.com/pred2", "http://example.com/uri2"]]),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=None,
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Two queries are issued against the graph
    assert graph.query.call_count == 2

    # One triple per query, all sharing the requested subject
    assert len(triples) == 2
    found = [(t.s.value, t.p.value, t.o.value) for t in triples]
    assert found[0] == (
        "http://example.com/subject",
        "http://example.com/pred1",
        "literal1",
    )
    assert found[1] == (
        "http://example.com/subject",
        "http://example.com/pred2",
        "http://example.com/uri2",
    )
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_po_query(self, mock_falkordb):
    """Fixed p and o: each of the two graph queries contributes one subject."""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # Each graph.query() call yields a different subject
    graph.query.side_effect = [
        MagicMock(result_set=[["http://example.com/subj1"]]),
        MagicMock(result_set=[["http://example.com/subj2"]]),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=Value(value="literal object", is_uri=False),
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Two queries are issued against the graph
    assert graph.query.call_count == 2

    # One triple per query; only the subject differs
    assert len(triples) == 2
    found = [(t.s.value, t.p.value, t.o.value) for t in triples]
    assert found[0] == (
        "http://example.com/subj1",
        "http://example.com/predicate",
        "literal object",
    )
    assert found[1] == (
        "http://example.com/subj2",
        "http://example.com/predicate",
        "literal object",
    )
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_p_query(self, mock_falkordb):
    """Fixed p only: each of the two graph queries contributes a subject/object pair."""
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # Rows are [subject, object]; the first object is a literal, the second a URI
    graph.query.side_effect = [
        MagicMock(result_set=[["http://example.com/subj1", "literal1"]]),
        MagicMock(result_set=[["http://example.com/subj2", "http://example.com/uri2"]]),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=None,
        limit=100,
    )

    triples = await processor.query_triples(request)

    # Two queries are issued against the graph
    assert graph.query.call_count == 2

    # One triple per query, all sharing the requested predicate
    assert len(triples) == 2
    found = [(t.s.value, t.p.value, t.o.value) for t in triples]
    assert found[0] == (
        "http://example.com/subj1",
        "http://example.com/predicate",
        "literal1",
    )
    assert found[1] == (
        "http://example.com/subj2",
        "http://example.com/predicate",
        "http://example.com/uri2",
    )
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_o_query(self, mock_falkordb):
    """Test O query (object only).

    Each mocked reply carries a (subject, predicate) row; the object in
    every returned triple must be the literal that was asked for.
    """
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    # One canned reply per graph.query() call, consumed in call order.
    graph.query.side_effect = [
        MagicMock(result_set=[["http://example.com/subj1", "http://example.com/pred1"]]),
        MagicMock(result_set=[["http://example.com/subj2", "http://example.com/pred2"]]),
    ]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=None,
        o=Value(value="literal object", is_uri=False),
        limit=100,
    )

    result = await processor.query_triples(request)

    # Two lookups are expected (labelled literal and URI in the suite).
    assert graph.query.call_count == 2

    subject_predicate_pairs = [
        ("http://example.com/subj1", "http://example.com/pred1"),
        ("http://example.com/subj2", "http://example.com/pred2"),
    ]
    assert len(result) == 2
    for index, (subject, predicate) in enumerate(subject_predicate_pairs):
        triple = result[index]
        assert triple.s.value == subject
        assert triple.p.value == predicate
        assert triple.o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_wildcard_query(self, mock_falkordb):
    """Test wildcard query (no constraints).

    With s, p and o all unset, each mocked reply supplies a complete
    (subject, predicate, object) row that must be returned as a triple.
    """
    graph = MagicMock()
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    rows = [
        ["http://example.com/s1", "http://example.com/p1", "literal1"],
        ["http://example.com/s2", "http://example.com/p2", "http://example.com/o2"],
    ]
    # One canned reply per graph.query() call, each holding a single row.
    graph.query.side_effect = [MagicMock(result_set=[row]) for row in rows]

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=None,
        o=None,
        limit=100,
    )

    result = await processor.query_triples(request)

    # Two lookups are expected (labelled literal and URI in the suite).
    assert graph.query.call_count == 2

    assert len(result) == 2
    for index, (subject, predicate, obj) in enumerate(rows):
        triple = result[index]
        assert triple.s.value == subject
        assert triple.p.value == predicate
        assert triple.o.value == obj
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
async def test_query_triples_exception_handling(self, mock_falkordb):
    """Test exception handling during query processing.

    An error raised by the mocked ``graph.query`` must surface from
    ``query_triples`` with its message intact.
    """
    failure = Exception("Database connection failed")

    graph = MagicMock()
    graph.query.side_effect = failure  # every query attempt raises
    client = MagicMock()
    client.select_graph.return_value = graph
    mock_falkordb.from_url.return_value = client

    processor = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=None,
        limit=100,
    )

    # The error must reach the caller.
    with pytest.raises(Exception, match="Database connection failed"):
        await processor.query_triples(request)
|
||||
|
||||
def test_add_args_method(self):
    """Test that add_args properly configures argument parser.

    The parent class's ``add_args`` is mocked out, so only the options
    registered by ``Processor.add_args`` itself end up on the parser.
    """
    import argparse
    from unittest import mock

    parser = argparse.ArgumentParser()
    parent_target = 'trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'

    with mock.patch(parent_target) as parent_add_args:
        Processor.add_args(parser)

    # The mock keeps its call record after the patch is undone.
    parent_add_args.assert_called_once()

    # Parsing an empty command line exposes the defaults.
    args = parser.parse_args([])

    expected_defaults = {
        'graph_url': 'falkor://falkordb:6379',
        'database': 'falkordb',
    }
    for option, default in expected_defaults.items():
        assert hasattr(args, option)
        assert getattr(args, option) == default
|
||||
|
||||
def test_add_args_with_custom_values(self):
    """Test add_args with custom command line values.

    Long-form ``--graph-url`` and ``--database`` options must override
    the defaults.
    """
    import argparse
    from unittest import mock

    parser = argparse.ArgumentParser()
    parent_target = 'trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'

    # Parent add_args is mocked out; only Processor's own options matter.
    with mock.patch(parent_target):
        Processor.add_args(parser)

    argv = [
        '--graph-url', 'falkor://custom:6379',
        '--database', 'querydb',
    ]
    args = parser.parse_args(argv)

    assert args.graph_url == 'falkor://custom:6379'
    assert args.database == 'querydb'
|
||||
|
||||
def test_add_args_short_form(self):
    """Test add_args with short form arguments.

    ``-g`` must populate ``graph_url``.
    """
    import argparse
    from unittest import mock

    parser = argparse.ArgumentParser()
    parent_target = 'trustgraph.query.triples.falkordb.service.TriplesQueryService.add_args'

    # Parent add_args is mocked out; only Processor's own options matter.
    with mock.patch(parent_target):
        Processor.add_args(parser)

    short_form_argv = ['-g', 'falkor://short:6379']
    args = parser.parse_args(short_form_argv)

    assert args.graph_url == 'falkor://short:6379'
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.Processor.launch')
def test_run_function(self, mock_launch):
    """Test the run function calls Processor.launch with correct parameters.

    ``Processor.launch`` is mocked, so nothing is actually started; the
    test only pins the two positional arguments ``run()`` hands over.
    """
    import trustgraph.query.triples.falkordb.service as service

    service.run()

    expected_description = "\nTriples query service for FalkorDB.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
    mock_launch.assert_called_once_with(
        service.default_ident,
        expected_description,
    )
|
||||
568
tests/unit/test_query/test_triples_memgraph_query.py
Normal file
568
tests/unit/test_query/test_triples_memgraph_query.py
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
"""
|
||||
Tests for Memgraph triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.memgraph.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
|
||||
|
||||
class TestMemgraphQueryProcessor:
    """Test cases for Memgraph query processor.

    Every test replaces ``GraphDatabase`` in the service module with a
    mock, so no Memgraph server is contacted.  The query tests feed canned
    records through the mocked driver's ``execute_query`` and check the
    triples that ``Processor.query_triples`` returns.
    """

    SUBJECT = "http://example.com/subject"
    PREDICATE = "http://example.com/predicate"
    LITERAL_OBJECT = "literal object"

    # ------------------------------------------------------------------
    # Shared helpers.  Underscore-prefixed, so pytest does not collect them.
    # ------------------------------------------------------------------

    @staticmethod
    def _uri(text):
        """Return a schema ``Value`` flagged as a URI."""
        return Value(value=text, is_uri=True)

    @staticmethod
    def _literal(text):
        """Return a schema ``Value`` flagged as a literal."""
        return Value(value=text, is_uri=False)

    @staticmethod
    def _record(**fields):
        """Return a mock driver record whose ``data()`` returns *fields*."""
        record = MagicMock()
        record.data.return_value = fields
        return record

    @staticmethod
    def _install_driver(mock_graph_db, *replies):
        """Make the patched ``GraphDatabase.driver()`` return a mock driver.

        Each item of *replies* is the record list handed back by one
        successive ``execute_query`` call, wrapped in the
        ``(records, None, None)`` tuple the tests use.  With no replies the
        driver's ``execute_query`` is left unconfigured.
        """
        driver = MagicMock()
        mock_graph_db.driver.return_value = driver
        if replies:
            driver.execute_query.side_effect = [
                (records, None, None) for records in replies
            ]
        return driver

    @staticmethod
    def _request(s=None, p=None, o=None):
        """Build a query request with the fixed user/collection/limit."""
        return TriplesQueryRequest(
            user='test_user',
            collection='test_collection',
            s=s,
            p=p,
            o=o,
            limit=100,
        )

    @staticmethod
    def _assert_triples(result, expected):
        """Check *result* against a list of ``(s, p, o)`` value tuples."""
        assert len(result) == len(expected)
        for index, (subject, predicate, obj) in enumerate(expected):
            triple = result[index]
            assert triple.s.value == subject
            assert triple.p.value == predicate
            assert triple.o.value == obj

    @staticmethod
    def _assert_value(result, text, is_uri):
        """Check that *result* is a ``Value`` with the given text and flag."""
        assert isinstance(result, Value)
        assert result.value == text
        assert result.is_uri is is_uri

    @staticmethod
    def _build_parser():
        """Return ``(parser, parent_mock)`` after ``Processor.add_args``.

        The parent class's ``add_args`` is mocked out so only the options
        registered by ``Processor.add_args`` itself land on the parser.
        """
        from argparse import ArgumentParser

        parser = ArgumentParser()
        with patch('trustgraph.query.triples.memgraph.service.TriplesQueryService.add_args') as parent_add_args:
            Processor.add_args(parser)
        return parser, parent_add_args

    # ------------------------------------------------------------------
    # Fixture
    # ------------------------------------------------------------------

    @pytest.fixture
    def processor(self):
        """Create a processor instance for testing"""
        with patch('trustgraph.query.triples.memgraph.service.GraphDatabase'):
            return Processor(
                taskgroup=MagicMock(),
                id='test-memgraph-query',
                graph_host='bolt://localhost:7687'
            )

    # ------------------------------------------------------------------
    # create_value
    # ------------------------------------------------------------------

    def test_create_value_with_http_uri(self, processor):
        """Test create_value with HTTP URI"""
        result = processor.create_value("http://example.com/resource")
        self._assert_value(result, "http://example.com/resource", True)

    def test_create_value_with_https_uri(self, processor):
        """Test create_value with HTTPS URI"""
        result = processor.create_value("https://example.com/resource")
        self._assert_value(result, "https://example.com/resource", True)

    def test_create_value_with_literal(self, processor):
        """Test create_value with literal value"""
        result = processor.create_value("just a literal string")
        self._assert_value(result, "just a literal string", False)

    def test_create_value_with_empty_string(self, processor):
        """Test create_value with empty string"""
        result = processor.create_value("")
        self._assert_value(result, "", False)

    def test_create_value_with_partial_uri(self, processor):
        """Test create_value with string that looks like URI but isn't complete"""
        result = processor.create_value("http")
        self._assert_value(result, "http", False)

    def test_create_value_with_ftp_uri(self, processor):
        """Test create_value with FTP URI (should not be detected as URI)"""
        result = processor.create_value("ftp://example.com/file")
        self._assert_value(result, "ftp://example.com/file", False)

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    def test_processor_initialization_with_defaults(self, mock_graph_db):
        """Test processor initialization with default parameters"""
        self._install_driver(mock_graph_db)

        processor = Processor(taskgroup=MagicMock())

        assert processor.db == 'memgraph'
        mock_graph_db.driver.assert_called_once_with(
            'bolt://memgraph:7687',
            auth=('memgraph', 'password')
        )

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    def test_processor_initialization_with_custom_params(self, mock_graph_db):
        """Test processor initialization with custom parameters"""
        self._install_driver(mock_graph_db)

        processor = Processor(
            taskgroup=MagicMock(),
            graph_host='bolt://custom:7687',
            username='queryuser',
            password='querypass',
            database='customdb'
        )

        assert processor.db == 'customdb'
        mock_graph_db.driver.assert_called_once_with(
            'bolt://custom:7687',
            auth=('queryuser', 'querypass')
        )

    # ------------------------------------------------------------------
    # query_triples.  Each test expects exactly two execute_query calls
    # (labelled "literal query" and "URI query" in the original suite).
    # ------------------------------------------------------------------

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_spo_query(self, mock_graph_db):
        """Test SPO query (all values specified)"""
        driver = self._install_driver(mock_graph_db)
        # Every execute_query call returns the same single-record reply.
        driver.execute_query.return_value = ([MagicMock()], None, None)

        processor = Processor(taskgroup=MagicMock())
        query = self._request(
            s=self._uri(self.SUBJECT),
            p=self._uri(self.PREDICATE),
            o=self._literal(self.LITERAL_OBJECT),
        )

        result = await processor.query_triples(query)

        assert driver.execute_query.call_count == 2

        # The queried triple appears twice, once per query; check both
        # entries rather than only the first.
        queried = (self.SUBJECT, self.PREDICATE, self.LITERAL_OBJECT)
        self._assert_triples(result, [queried, queried])

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_sp_query(self, mock_graph_db):
        """Test SP query (subject and predicate specified)"""
        driver = self._install_driver(
            mock_graph_db,
            [self._record(dest="literal result")],
            [self._record(dest="http://example.com/uri_result")],
        )

        processor = Processor(taskgroup=MagicMock())
        query = self._request(
            s=self._uri(self.SUBJECT),
            p=self._uri(self.PREDICATE),
        )

        result = await processor.query_triples(query)

        assert driver.execute_query.call_count == 2
        self._assert_triples(result, [
            (self.SUBJECT, self.PREDICATE, "literal result"),
            (self.SUBJECT, self.PREDICATE, "http://example.com/uri_result"),
        ])

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_so_query(self, mock_graph_db):
        """Test SO query (subject and object specified)"""
        driver = self._install_driver(
            mock_graph_db,
            [self._record(rel="http://example.com/pred1")],
            [self._record(rel="http://example.com/pred2")],
        )

        processor = Processor(taskgroup=MagicMock())
        query = self._request(
            s=self._uri(self.SUBJECT),
            o=self._literal(self.LITERAL_OBJECT),
        )

        result = await processor.query_triples(query)

        assert driver.execute_query.call_count == 2
        self._assert_triples(result, [
            (self.SUBJECT, "http://example.com/pred1", self.LITERAL_OBJECT),
            (self.SUBJECT, "http://example.com/pred2", self.LITERAL_OBJECT),
        ])

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_s_query(self, mock_graph_db):
        """Test S query (subject only)"""
        driver = self._install_driver(
            mock_graph_db,
            [self._record(rel="http://example.com/pred1", dest="literal1")],
            [self._record(rel="http://example.com/pred2", dest="http://example.com/uri2")],
        )

        processor = Processor(taskgroup=MagicMock())
        query = self._request(s=self._uri(self.SUBJECT))

        result = await processor.query_triples(query)

        assert driver.execute_query.call_count == 2
        self._assert_triples(result, [
            (self.SUBJECT, "http://example.com/pred1", "literal1"),
            (self.SUBJECT, "http://example.com/pred2", "http://example.com/uri2"),
        ])

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_po_query(self, mock_graph_db):
        """Test PO query (predicate and object specified)"""
        driver = self._install_driver(
            mock_graph_db,
            [self._record(src="http://example.com/subj1")],
            [self._record(src="http://example.com/subj2")],
        )

        processor = Processor(taskgroup=MagicMock())
        query = self._request(
            p=self._uri(self.PREDICATE),
            o=self._literal(self.LITERAL_OBJECT),
        )

        result = await processor.query_triples(query)

        assert driver.execute_query.call_count == 2
        self._assert_triples(result, [
            ("http://example.com/subj1", self.PREDICATE, self.LITERAL_OBJECT),
            ("http://example.com/subj2", self.PREDICATE, self.LITERAL_OBJECT),
        ])

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_p_query(self, mock_graph_db):
        """Test P query (predicate only)"""
        driver = self._install_driver(
            mock_graph_db,
            [self._record(src="http://example.com/subj1", dest="literal1")],
            [self._record(src="http://example.com/subj2", dest="http://example.com/uri2")],
        )

        processor = Processor(taskgroup=MagicMock())
        query = self._request(p=self._uri(self.PREDICATE))

        result = await processor.query_triples(query)

        assert driver.execute_query.call_count == 2
        self._assert_triples(result, [
            ("http://example.com/subj1", self.PREDICATE, "literal1"),
            ("http://example.com/subj2", self.PREDICATE, "http://example.com/uri2"),
        ])

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_o_query(self, mock_graph_db):
        """Test O query (object only)"""
        driver = self._install_driver(
            mock_graph_db,
            [self._record(src="http://example.com/subj1", rel="http://example.com/pred1")],
            [self._record(src="http://example.com/subj2", rel="http://example.com/pred2")],
        )

        processor = Processor(taskgroup=MagicMock())
        query = self._request(o=self._literal(self.LITERAL_OBJECT))

        result = await processor.query_triples(query)

        assert driver.execute_query.call_count == 2
        self._assert_triples(result, [
            ("http://example.com/subj1", "http://example.com/pred1", self.LITERAL_OBJECT),
            ("http://example.com/subj2", "http://example.com/pred2", self.LITERAL_OBJECT),
        ])

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_wildcard_query(self, mock_graph_db):
        """Test wildcard query (no constraints)"""
        driver = self._install_driver(
            mock_graph_db,
            [self._record(src="http://example.com/s1", rel="http://example.com/p1", dest="literal1")],
            [self._record(src="http://example.com/s2", rel="http://example.com/p2", dest="http://example.com/o2")],
        )

        processor = Processor(taskgroup=MagicMock())
        query = self._request()

        result = await processor.query_triples(query)

        assert driver.execute_query.call_count == 2
        self._assert_triples(result, [
            ("http://example.com/s1", "http://example.com/p1", "literal1"),
            ("http://example.com/s2", "http://example.com/p2", "http://example.com/o2"),
        ])

    @patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
    @pytest.mark.asyncio
    async def test_query_triples_exception_handling(self, mock_graph_db):
        """Test exception handling during query processing"""
        driver = self._install_driver(mock_graph_db)
        # Every execute_query call raises.
        driver.execute_query.side_effect = Exception("Database connection failed")

        processor = Processor(taskgroup=MagicMock())
        query = self._request(s=self._uri(self.SUBJECT))

        # The error must reach the caller.
        with pytest.raises(Exception, match="Database connection failed"):
            await processor.query_triples(query)

    # ------------------------------------------------------------------
    # Command-line arguments
    # ------------------------------------------------------------------

    def test_add_args_method(self):
        """Test that add_args properly configures argument parser"""
        parser, parent_add_args = self._build_parser()

        # Verify parent add_args was called
        parent_add_args.assert_called_once()

        # Parse empty args to check defaults
        args = parser.parse_args([])

        assert hasattr(args, 'graph_host')
        assert args.graph_host == 'bolt://memgraph:7687'
        assert hasattr(args, 'username')
        assert args.username == 'memgraph'
        assert hasattr(args, 'password')
        assert args.password == 'password'
        assert hasattr(args, 'database')
        assert args.database == 'memgraph'

    def test_add_args_with_custom_values(self):
        """Test add_args with custom command line values"""
        parser, _ = self._build_parser()

        args = parser.parse_args([
            '--graph-host', 'bolt://custom:7687',
            '--username', 'queryuser',
            '--password', 'querypass',
            '--database', 'querydb'
        ])

        assert args.graph_host == 'bolt://custom:7687'
        assert args.username == 'queryuser'
        assert args.password == 'querypass'
        assert args.database == 'querydb'

    def test_add_args_short_form(self):
        """Test add_args with short form arguments"""
        parser, _ = self._build_parser()

        args = parser.parse_args(['-g', 'bolt://short:7687'])

        assert args.graph_host == 'bolt://short:7687'

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    @patch('trustgraph.query.triples.memgraph.service.Processor.launch')
    def test_run_function(self, mock_launch):
        """Test the run function calls Processor.launch with correct parameters"""
        from trustgraph.query.triples.memgraph.service import run, default_ident

        run()

        mock_launch.assert_called_once_with(
            default_ident,
            "\nTriples query service for memgraph.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"
        )
|
||||
338
tests/unit/test_query/test_triples_neo4j_query.py
Normal file
338
tests/unit/test_query/test_triples_neo4j_query.py
Normal file
|
|
@ -0,0 +1,338 @@
|
|||
"""
|
||||
Tests for Neo4j triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.neo4j.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
|
||||
|
||||
class TestNeo4jQueryProcessor:
|
||||
"""Test cases for Neo4j query processor"""
|
||||
|
||||
@pytest.fixture
def processor(self):
    """Provide a Processor built while ``GraphDatabase`` is patched out.

    The patch is active only during construction, so creating the
    processor does not reach the real ``GraphDatabase``.
    """
    init_kwargs = {
        'taskgroup': MagicMock(),
        'id': 'test-neo4j-query',
        'graph_host': 'bolt://localhost:7687',
    }
    with patch('trustgraph.query.triples.neo4j.service.GraphDatabase'):
        instance = Processor(**init_kwargs)
    return instance
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
    """An ``http://`` string must come back as a Value flagged as a URI."""
    text = "http://example.com/resource"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
    """create_value marks an https:// string as a URI"""
    text = "https://example.com/resource"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is True
|
||||
|
||||
def test_create_value_with_literal(self, processor):
    """create_value leaves is_uri False for a plain literal string"""
    text = "just a literal string"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
    """create_value treats the empty string as a non-URI value"""
    text = ""

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
    """create_value does not flag the bare word 'http' as a URI"""
    text = "http"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
    """create_value does not flag an ftp:// string as a URI"""
    text = "ftp://example.com/file"

    value = processor.create_value(text)

    assert isinstance(value, Value)
    assert value.value == text
    assert value.is_uri is False
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
    """Processor built with only a taskgroup uses the default connection settings"""
    mock_graph_db.driver.return_value = MagicMock()

    proc = Processor(taskgroup=MagicMock())

    assert proc.db == 'neo4j'
    # The driver must have been created exactly once with the default host and auth.
    mock_graph_db.driver.assert_called_once_with(
        'bolt://neo4j:7687',
        auth=('neo4j', 'password'),
    )
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
def test_processor_initialization_with_custom_params(self, mock_graph_db):
    """Processor passes explicitly supplied connection settings to the driver"""
    mock_graph_db.driver.return_value = MagicMock()

    overrides = {
        'graph_host': 'bolt://custom:7687',
        'username': 'queryuser',
        'password': 'querypass',
        'database': 'customdb',
    }
    proc = Processor(taskgroup=MagicMock(), **overrides)

    assert proc.db == 'customdb'
    mock_graph_db.driver.assert_called_once_with(
        'bolt://custom:7687',
        auth=('queryuser', 'querypass'),
    )
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_spo_query(self, mock_graph_db):
    """SPO query: subject, predicate and object are all specified"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver
    # Every execute_query call yields the same single-record result.
    driver.execute_query.return_value = ([MagicMock()], None, None)

    proc = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=Value(value="literal object", is_uri=False),
        limit=100,
    )

    triples = await proc.query_triples(request)

    # Two executions are expected (one literal query, one URI query).
    assert driver.execute_query.call_count == 2

    # One triple per executed query, so the queried triple shows up twice.
    assert len(triples) == 2
    first = triples[0]
    assert first.s.value == "http://example.com/subject"
    assert first.p.value == "http://example.com/predicate"
    assert first.o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_sp_query(self, mock_graph_db):
    """SP query: subject and predicate specified, object left open"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # One record per expected execute_query call, each with a different object.
    literal_row = MagicMock()
    literal_row.data.return_value = {"dest": "literal result"}
    uri_row = MagicMock()
    uri_row.data.return_value = {"dest": "http://example.com/uri_result"}
    driver.execute_query.side_effect = [
        ([literal_row], None, None),  # Literal query
        ([uri_row], None, None),      # URI query
    ]

    proc = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=Value(value="http://example.com/predicate", is_uri=True),
        o=None,
        limit=100,
    )

    triples = await proc.query_triples(request)

    # Both the literal and the URI query must have run.
    assert driver.execute_query.call_count == 2

    # Subject and predicate echo the request; objects come from the mocked rows.
    assert len(triples) == 2
    expected_objects = ["literal result", "http://example.com/uri_result"]
    for triple, expected_object in zip(triples, expected_objects):
        assert triple.s.value == "http://example.com/subject"
        assert triple.p.value == "http://example.com/predicate"
        assert triple.o.value == expected_object
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_wildcard_query(self, mock_graph_db):
    """Wildcard query: subject, predicate and object are all unconstrained"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver

    # Each mocked record carries a full (src, rel, dest) row.
    literal_row = MagicMock()
    literal_row.data.return_value = {"src": "http://example.com/s1", "rel": "http://example.com/p1", "dest": "literal1"}
    uri_row = MagicMock()
    uri_row.data.return_value = {"src": "http://example.com/s2", "rel": "http://example.com/p2", "dest": "http://example.com/o2"}
    driver.execute_query.side_effect = [
        ([literal_row], None, None),  # Literal query
        ([uri_row], None, None),      # URI query
    ]

    proc = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=None,
        p=None,
        o=None,
        limit=100,
    )

    triples = await proc.query_triples(request)

    # Both the literal and the URI query must have run.
    assert driver.execute_query.call_count == 2

    # Results are expected in the same order as the mocked query results.
    assert len(triples) == 2
    expected_triples = [
        ("http://example.com/s1", "http://example.com/p1", "literal1"),
        ("http://example.com/s2", "http://example.com/p2", "http://example.com/o2"),
    ]
    for triple, (subj, pred, obj) in zip(triples, expected_triples):
        assert triple.s.value == subj
        assert triple.p.value == pred
        assert triple.o.value == obj
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_exception_handling(self, mock_graph_db):
    """An error raised by execute_query propagates out of query_triples"""
    driver = MagicMock()
    mock_graph_db.driver.return_value = driver
    # Any query execution fails with this error.
    driver.execute_query.side_effect = Exception("Database connection failed")

    proc = Processor(taskgroup=MagicMock())

    request = TriplesQueryRequest(
        user='test_user',
        collection='test_collection',
        s=Value(value="http://example.com/subject", is_uri=True),
        p=None,
        o=None,
        limit=100,
    )

    # The failure must surface to the caller rather than being swallowed.
    with pytest.raises(Exception, match="Database connection failed"):
        await proc.query_triples(request)
|
||||
|
||||
def test_add_args_method(self):
    """Test that add_args properly configures argument parser

    Checks that Processor.add_args calls the parent class add_args once
    and registers graph_host, username, password and database with the
    expected defaults.
    """
    # `patch` is already imported at module level, so only argparse is
    # imported locally here.
    from argparse import ArgumentParser

    parser = ArgumentParser()

    # Mock the parent class add_args method
    with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args') as mock_parent_add_args:
        Processor.add_args(parser)

    # Verify parent add_args was called
    mock_parent_add_args.assert_called_once()

    # Parse empty args so every option falls back to its default
    args = parser.parse_args([])

    expected_defaults = {
        'graph_host': 'bolt://neo4j:7687',
        'username': 'neo4j',
        'password': 'password',
        'database': 'neo4j',
    }
    for name, default in expected_defaults.items():
        assert hasattr(args, name)
        assert getattr(args, name) == default
|
||||
|
||||
def test_add_args_with_custom_values(self):
    """Test add_args with custom command line values

    Registers the processor's arguments on a fresh parser (parent
    add_args patched out) and checks that long-form options override
    the defaults.
    """
    # `patch` is already imported at module level, so only argparse is
    # imported locally here.
    from argparse import ArgumentParser

    parser = ArgumentParser()

    with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'):
        Processor.add_args(parser)

    # Test parsing with custom values
    args = parser.parse_args([
        '--graph-host', 'bolt://custom:7687',
        '--username', 'queryuser',
        '--password', 'querypass',
        '--database', 'querydb',
    ])

    assert args.graph_host == 'bolt://custom:7687'
    assert args.username == 'queryuser'
    assert args.password == 'querypass'
    assert args.database == 'querydb'
|
||||
|
||||
def test_add_args_short_form(self):
    """Test add_args with short form arguments

    Checks that the -g short option sets graph_host.
    """
    # `patch` is already imported at module level, so only argparse is
    # imported locally here.
    from argparse import ArgumentParser

    parser = ArgumentParser()

    with patch('trustgraph.query.triples.neo4j.service.TriplesQueryService.add_args'):
        Processor.add_args(parser)

    # Test parsing with short form
    args = parser.parse_args(['-g', 'bolt://short:7687'])

    assert args.graph_host == 'bolt://short:7687'
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.Processor.launch')
def test_run_function(self, mock_launch):
    """Check that run() forwards the ident and description to Processor.launch"""
    from trustgraph.query.triples.neo4j.service import run, default_ident

    # Description text the service is expected to pass as the second argument.
    expected_description = "\nTriples query service for neo4j.\nInput is a (s, p, o) triple, some values may be null. Output is a list of\ntriples.\n"

    run()

    mock_launch.assert_called_once_with(default_ident, expected_description)
|
||||
Loading…
Add table
Add a link
Reference in a new issue