mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 00:46:22 +02:00
Merge 2.0 to master (#651)
This commit is contained in:
parent
3666ece2c5
commit
b9d7bf9a8b
212 changed files with 13940 additions and 6180 deletions
|
|
@ -88,8 +88,13 @@ async def test_subscriber_deferred_acknowledgment_success():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscriber_deferred_acknowledgment_failure():
|
||||
"""Verify Subscriber negative acks on delivery failure."""
|
||||
async def test_subscriber_dropped_message_still_acks():
|
||||
"""Verify Subscriber acks even when message is dropped (backpressure).
|
||||
|
||||
This prevents redelivery storms on shared topics - if we negative_ack
|
||||
a dropped message, it gets redelivered to all subscribers, none of
|
||||
whom can handle it either, causing a tight redelivery loop.
|
||||
"""
|
||||
mock_backend = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_backend.create_consumer.return_value = mock_consumer
|
||||
|
|
@ -103,24 +108,66 @@ async def test_subscriber_deferred_acknowledgment_failure():
|
|||
max_size=1, # Very small queue
|
||||
backpressure_strategy="drop_new"
|
||||
)
|
||||
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
|
||||
|
||||
# Create queue and fill it
|
||||
queue = await subscriber.subscribe("test-queue")
|
||||
await queue.put({"existing": "data"})
|
||||
|
||||
# Create mock message - should be dropped
|
||||
msg = create_mock_message("msg-1", {"data": "test"})
|
||||
|
||||
# Process message (should fail due to full queue + drop_new strategy)
|
||||
|
||||
# Create mock message - should be dropped due to full queue
|
||||
msg = create_mock_message("test-queue", {"data": "test"})
|
||||
|
||||
# Process message (should be dropped due to full queue + drop_new strategy)
|
||||
await subscriber._process_message(msg)
|
||||
|
||||
# Should negative acknowledge failed delivery
|
||||
mock_consumer.negative_acknowledge.assert_called_once_with(msg)
|
||||
mock_consumer.acknowledge.assert_not_called()
|
||||
|
||||
|
||||
# Should acknowledge even though delivery failed - prevents redelivery storm
|
||||
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||
mock_consumer.negative_acknowledge.assert_not_called()
|
||||
|
||||
# Clean up
|
||||
await subscriber.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscriber_orphaned_message_acks():
|
||||
"""Verify Subscriber acks orphaned messages (no matching waiter).
|
||||
|
||||
On shared response topics, if a message arrives for a waiter that
|
||||
no longer exists (e.g., client disconnected, request timed out),
|
||||
we must acknowledge it to prevent redelivery storms.
|
||||
"""
|
||||
mock_backend = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_backend.create_consumer.return_value = mock_consumer
|
||||
|
||||
subscriber = Subscriber(
|
||||
backend=mock_backend,
|
||||
topic="test-topic",
|
||||
subscription="test-subscription",
|
||||
consumer_name="test-consumer",
|
||||
schema=dict,
|
||||
max_size=10,
|
||||
backpressure_strategy="block"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
|
||||
# Don't create any queues - message will be orphaned
|
||||
# This simulates a response arriving after the waiter has unsubscribed
|
||||
|
||||
# Create mock message with an ID that has no matching waiter
|
||||
msg = create_mock_message("non-existent-waiter-id", {"data": "orphaned"})
|
||||
|
||||
# Process message (should be orphaned - no matching waiter)
|
||||
await subscriber._process_message(msg)
|
||||
|
||||
# Should acknowledge orphaned message - prevents redelivery storm
|
||||
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||
mock_consumer.negative_acknowledge.assert_not_called()
|
||||
|
||||
# Clean up
|
||||
await subscriber.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -55,6 +55,9 @@ class TestSetToolStructuredQuery:
|
|||
mcp_tool=None,
|
||||
collection="sales_data",
|
||||
template=None,
|
||||
schema_name=None,
|
||||
index_name=None,
|
||||
limit=None,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
|
|
@ -92,6 +95,9 @@ class TestSetToolStructuredQuery:
|
|||
mcp_tool=None,
|
||||
collection=None, # No collection specified
|
||||
template=None,
|
||||
schema_name=None,
|
||||
index_name=None,
|
||||
limit=None,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
|
|
@ -132,6 +138,9 @@ class TestSetToolStructuredQuery:
|
|||
mcp_tool=None,
|
||||
collection='sales_data',
|
||||
template=None,
|
||||
schema_name=None,
|
||||
index_name=None,
|
||||
limit=None,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
|
|
@ -201,6 +210,144 @@ class TestSetToolStructuredQuery:
|
|||
assert 'Exception:' in printed_output or 'invalid choice:' in printed_output.lower()
|
||||
|
||||
|
||||
class TestSetToolRowEmbeddingsQuery:
|
||||
"""Test the set_tool function with row-embeddings-query type."""
|
||||
|
||||
@patch('trustgraph.cli.set_tool.Api')
|
||||
def test_set_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys):
|
||||
"""Test setting a row-embeddings-query tool with all parameters."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.get.return_value = []
|
||||
|
||||
set_tool(
|
||||
url="http://test.com",
|
||||
id="customer_search",
|
||||
name="find_customer",
|
||||
description="Find customers by name using semantic search",
|
||||
type="row-embeddings-query",
|
||||
mcp_tool=None,
|
||||
collection="sales",
|
||||
template=None,
|
||||
schema_name="customers",
|
||||
index_name="full_name",
|
||||
limit=20,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
applicable_states=None
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Tool set." in captured.out
|
||||
|
||||
# Verify the tool was stored correctly
|
||||
call_args = mock_config.put.call_args[0][0]
|
||||
assert len(call_args) == 1
|
||||
config_value = call_args[0]
|
||||
assert config_value.type == "tool"
|
||||
assert config_value.key == "customer_search"
|
||||
|
||||
stored_tool = json.loads(config_value.value)
|
||||
assert stored_tool["name"] == "find_customer"
|
||||
assert stored_tool["type"] == "row-embeddings-query"
|
||||
assert stored_tool["collection"] == "sales"
|
||||
assert stored_tool["schema-name"] == "customers"
|
||||
assert stored_tool["index-name"] == "full_name"
|
||||
assert stored_tool["limit"] == 20
|
||||
|
||||
@patch('trustgraph.cli.set_tool.Api')
|
||||
def test_set_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys):
|
||||
"""Test setting row-embeddings-query tool with minimal parameters."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.get.return_value = []
|
||||
|
||||
set_tool(
|
||||
url="http://test.com",
|
||||
id="product_search",
|
||||
name="find_product",
|
||||
description="Find products using semantic search",
|
||||
type="row-embeddings-query",
|
||||
mcp_tool=None,
|
||||
collection=None,
|
||||
template=None,
|
||||
schema_name="products",
|
||||
index_name=None, # No index filter
|
||||
limit=None, # Use default
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
applicable_states=None
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Tool set." in captured.out
|
||||
|
||||
call_args = mock_config.put.call_args[0][0]
|
||||
stored_tool = json.loads(call_args[0].value)
|
||||
assert stored_tool["type"] == "row-embeddings-query"
|
||||
assert stored_tool["schema-name"] == "products"
|
||||
assert "index-name" not in stored_tool # Should not be included if None
|
||||
assert "limit" not in stored_tool # Should not be included if None
|
||||
assert "collection" not in stored_tool # Should not be included if None
|
||||
|
||||
def test_set_main_row_embeddings_query_with_all_options(self):
|
||||
"""Test set main() with row-embeddings-query tool type and all options."""
|
||||
test_args = [
|
||||
'tg-set-tool',
|
||||
'--id', 'customer_search',
|
||||
'--name', 'find_customer',
|
||||
'--type', 'row-embeddings-query',
|
||||
'--description', 'Find customers by name',
|
||||
'--schema-name', 'customers',
|
||||
'--collection', 'sales',
|
||||
'--index-name', 'full_name',
|
||||
'--limit', '25',
|
||||
'--api-url', 'http://custom.com'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('trustgraph.cli.set_tool.set_tool') as mock_set:
|
||||
|
||||
set_main()
|
||||
|
||||
mock_set.assert_called_once_with(
|
||||
url='http://custom.com',
|
||||
id='customer_search',
|
||||
name='find_customer',
|
||||
description='Find customers by name',
|
||||
type='row-embeddings-query',
|
||||
mcp_tool=None,
|
||||
collection='sales',
|
||||
template=None,
|
||||
schema_name='customers',
|
||||
index_name='full_name',
|
||||
limit=25,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
)
|
||||
|
||||
def test_valid_types_includes_row_embeddings_query(self):
|
||||
"""Test that 'row-embeddings-query' is included in valid tool types."""
|
||||
test_args = [
|
||||
'tg-set-tool',
|
||||
'--id', 'test_tool',
|
||||
'--name', 'test_tool',
|
||||
'--type', 'row-embeddings-query',
|
||||
'--description', 'Test tool',
|
||||
'--schema-name', 'test_schema'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('trustgraph.cli.set_tool.set_tool') as mock_set:
|
||||
|
||||
# Should not raise an exception about invalid type
|
||||
set_main()
|
||||
mock_set.assert_called_once()
|
||||
|
||||
|
||||
class TestShowToolsStructuredQuery:
|
||||
"""Test the show_tools function with structured-query tools."""
|
||||
|
||||
|
|
@ -259,9 +406,9 @@ class TestShowToolsStructuredQuery:
|
|||
|
||||
@patch('trustgraph.cli.show_tools.Api')
|
||||
def test_show_mixed_tool_types(self, mock_api_class, mock_api, capsys):
|
||||
"""Test displaying multiple tool types including structured-query."""
|
||||
"""Test displaying multiple tool types including structured-query and row-embeddings-query."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": "ask_knowledge",
|
||||
|
|
@ -270,37 +417,47 @@ class TestShowToolsStructuredQuery:
|
|||
"collection": "docs"
|
||||
},
|
||||
{
|
||||
"name": "query_data",
|
||||
"name": "query_data",
|
||||
"description": "Query structured data",
|
||||
"type": "structured-query",
|
||||
"collection": "sales"
|
||||
},
|
||||
{
|
||||
"name": "find_customer",
|
||||
"description": "Find customers by semantic search",
|
||||
"type": "row-embeddings-query",
|
||||
"schema-name": "customers",
|
||||
"collection": "crm"
|
||||
},
|
||||
{
|
||||
"name": "complete_text",
|
||||
"description": "Generate text",
|
||||
"type": "text-completion"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
config_values = [
|
||||
ConfigValue(type="tool", key=f"tool_{i}", value=json.dumps(tool))
|
||||
for i, tool in enumerate(tools)
|
||||
]
|
||||
mock_config.get_values.return_value = config_values
|
||||
|
||||
|
||||
show_config("http://test.com")
|
||||
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = captured.out
|
||||
|
||||
|
||||
# All tool types should be displayed
|
||||
assert "knowledge-query" in output
|
||||
assert "structured-query" in output
|
||||
assert "structured-query" in output
|
||||
assert "row-embeddings-query" in output
|
||||
assert "text-completion" in output
|
||||
|
||||
|
||||
# Collections should be shown for appropriate tools
|
||||
assert "docs" in output # knowledge-query collection
|
||||
assert "sales" in output # structured-query collection
|
||||
assert "crm" in output # row-embeddings-query collection
|
||||
assert "customers" in output # row-embeddings-query schema-name
|
||||
|
||||
def test_show_main_parses_args_correctly(self):
|
||||
"""Test that show main() parses arguments correctly."""
|
||||
|
|
@ -317,6 +474,76 @@ class TestShowToolsStructuredQuery:
|
|||
mock_show.assert_called_once_with(url='http://custom.com', token=None)
|
||||
|
||||
|
||||
class TestShowToolsRowEmbeddingsQuery:
|
||||
"""Test the show_tools function with row-embeddings-query tools."""
|
||||
|
||||
@patch('trustgraph.cli.show_tools.Api')
|
||||
def test_show_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys):
|
||||
"""Test displaying a row-embeddings-query tool with all fields."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
tool_config = {
|
||||
"name": "find_customer",
|
||||
"description": "Find customers by name using semantic search",
|
||||
"type": "row-embeddings-query",
|
||||
"collection": "sales",
|
||||
"schema-name": "customers",
|
||||
"index-name": "full_name",
|
||||
"limit": 20
|
||||
}
|
||||
|
||||
config_value = ConfigValue(
|
||||
type="tool",
|
||||
key="customer_search",
|
||||
value=json.dumps(tool_config)
|
||||
)
|
||||
mock_config.get_values.return_value = [config_value]
|
||||
|
||||
show_config("http://test.com")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = captured.out
|
||||
|
||||
# Check that tool information is displayed
|
||||
assert "customer_search" in output
|
||||
assert "find_customer" in output
|
||||
assert "row-embeddings-query" in output
|
||||
assert "sales" in output # Collection
|
||||
assert "customers" in output # Schema name
|
||||
assert "full_name" in output # Index name
|
||||
assert "20" in output # Limit
|
||||
|
||||
@patch('trustgraph.cli.show_tools.Api')
|
||||
def test_show_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys):
|
||||
"""Test displaying row-embeddings-query tool with minimal fields."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
tool_config = {
|
||||
"name": "find_product",
|
||||
"description": "Find products using semantic search",
|
||||
"type": "row-embeddings-query",
|
||||
"schema-name": "products"
|
||||
# No collection, index-name, or limit
|
||||
}
|
||||
|
||||
config_value = ConfigValue(
|
||||
type="tool",
|
||||
key="product_search",
|
||||
value=json.dumps(tool_config)
|
||||
)
|
||||
mock_config.get_values.return_value = [config_value]
|
||||
|
||||
show_config("http://test.com")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = captured.out
|
||||
|
||||
# Should display the tool with schema-name
|
||||
assert "product_search" in output
|
||||
assert "row-embeddings-query" in output
|
||||
assert "products" in output # Schema name
|
||||
|
||||
|
||||
class TestStructuredQueryToolValidation:
|
||||
"""Test validation specific to structured-query tools."""
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
|||
from unittest.mock import call
|
||||
|
||||
from trustgraph.cores.knowledge import KnowledgeManager
|
||||
from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Value, EntityEmbeddings
|
||||
from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, EntityEmbeddings, IRI, LITERAL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -71,15 +71,15 @@ def sample_triples():
|
|||
return Triples(
|
||||
metadata=Metadata(
|
||||
id="test-doc-id",
|
||||
user="test-user",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
metadata=[]
|
||||
),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.org/john", is_uri=True),
|
||||
p=Value(value="http://example.org/name", is_uri=True),
|
||||
o=Value(value="John Smith", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/john"),
|
||||
p=Term(type=IRI, iri="http://example.org/name"),
|
||||
o=Term(type=LITERAL, value="John Smith")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -97,7 +97,7 @@ def sample_graph_embeddings():
|
|||
),
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=Value(value="http://example.org/john", is_uri=True),
|
||||
entity=Term(type=IRI, iri="http://example.org/john"),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
]
|
||||
|
|
|
|||
599
tests/unit/test_direct/test_entity_centric_kg.py
Normal file
599
tests/unit/test_direct/test_entity_centric_kg.py
Normal file
|
|
@ -0,0 +1,599 @@
|
|||
"""
|
||||
Unit tests for EntityCentricKnowledgeGraph class
|
||||
|
||||
Tests the entity-centric knowledge graph implementation without requiring
|
||||
an actual Cassandra connection. Uses mocking to verify correct behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
import os
|
||||
|
||||
|
||||
class TestEntityCentricKnowledgeGraph:
|
||||
"""Test cases for EntityCentricKnowledgeGraph"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cluster(self):
|
||||
"""Create a mock Cassandra cluster"""
|
||||
with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cluster_cls:
|
||||
mock_cluster = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_cluster.connect.return_value = mock_session
|
||||
mock_cluster_cls.return_value = mock_cluster
|
||||
yield mock_cluster_cls, mock_cluster, mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def entity_kg(self, mock_cluster):
|
||||
"""Create an EntityCentricKnowledgeGraph instance with mocked Cassandra"""
|
||||
from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
|
||||
mock_cluster_cls, mock_cluster, mock_session = mock_cluster
|
||||
|
||||
# Create instance
|
||||
kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace')
|
||||
return kg, mock_session
|
||||
|
||||
def test_init_creates_entity_centric_schema(self, mock_cluster):
|
||||
"""Test that initialization creates the 2-table entity-centric schema"""
|
||||
from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
|
||||
mock_cluster_cls, mock_cluster, mock_session = mock_cluster
|
||||
|
||||
kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace')
|
||||
|
||||
# Verify schema tables were created
|
||||
execute_calls = mock_session.execute.call_args_list
|
||||
executed_statements = [str(c) for c in execute_calls]
|
||||
|
||||
# Check for keyspace creation
|
||||
keyspace_created = any('create keyspace' in str(c).lower() for c in execute_calls)
|
||||
assert keyspace_created
|
||||
|
||||
# Check for quads_by_entity table
|
||||
entity_table_created = any('quads_by_entity' in str(c) for c in execute_calls)
|
||||
assert entity_table_created
|
||||
|
||||
# Check for quads_by_collection table
|
||||
collection_table_created = any('quads_by_collection' in str(c) for c in execute_calls)
|
||||
assert collection_table_created
|
||||
|
||||
# Check for collection_metadata table
|
||||
metadata_table_created = any('collection_metadata' in str(c) for c in execute_calls)
|
||||
assert metadata_table_created
|
||||
|
||||
def test_prepare_statements_initialized(self, entity_kg):
|
||||
"""Test that prepared statements are initialized"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
# Verify prepare was called for various statements
|
||||
assert mock_session.prepare.called
|
||||
prepare_calls = mock_session.prepare.call_args_list
|
||||
|
||||
# Check that key prepared statements exist
|
||||
prepared_queries = [str(c) for c in prepare_calls]
|
||||
|
||||
# Insert statements
|
||||
insert_entity_stmt = any('INSERT INTO' in str(c) and 'quads_by_entity' in str(c)
|
||||
for c in prepare_calls)
|
||||
assert insert_entity_stmt
|
||||
|
||||
insert_collection_stmt = any('INSERT INTO' in str(c) and 'quads_by_collection' in str(c)
|
||||
for c in prepare_calls)
|
||||
assert insert_collection_stmt
|
||||
|
||||
def test_insert_uri_object_creates_4_entity_rows(self, entity_kg):
|
||||
"""Test that inserting a quad with URI object creates 4 entity rows"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
# Reset mocks to track only insert-related calls
|
||||
mock_session.reset_mock()
|
||||
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/knows',
|
||||
o='http://example.org/Bob',
|
||||
g='http://example.org/graph1',
|
||||
otype='u'
|
||||
)
|
||||
|
||||
# Verify batch was executed
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_insert_literal_object_creates_3_entity_rows(self, entity_kg):
|
||||
"""Test that inserting a quad with literal object creates 3 entity rows"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.reset_mock()
|
||||
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://www.w3.org/2000/01/rdf-schema#label',
|
||||
o='Alice Smith',
|
||||
g=None,
|
||||
otype='l',
|
||||
dtype='xsd:string',
|
||||
lang='en'
|
||||
)
|
||||
|
||||
# Verify batch was executed
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_insert_default_graph(self, entity_kg):
|
||||
"""Test that None graph is stored as empty string"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.reset_mock()
|
||||
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/knows',
|
||||
o='http://example.org/Bob',
|
||||
g=None,
|
||||
otype='u'
|
||||
)
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_insert_auto_detects_otype(self, entity_kg):
|
||||
"""Test that otype is auto-detected when not provided"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.reset_mock()
|
||||
|
||||
# URI should be auto-detected
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/knows',
|
||||
o='http://example.org/Bob'
|
||||
)
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
mock_session.reset_mock()
|
||||
|
||||
# Literal should be auto-detected
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/name',
|
||||
o='Alice'
|
||||
)
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_get_s_returns_quads_for_subject(self, entity_kg):
|
||||
"""Test get_s queries by subject"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
# Mock the query result
|
||||
mock_result = [
|
||||
MagicMock(p='http://example.org/knows', o='http://example.org/Bob',
|
||||
d='', otype='u', dtype='', lang='', s='http://example.org/Alice')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_s('test_collection', 'http://example.org/Alice')
|
||||
|
||||
# Verify query was executed
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
# Results should be QuadResult objects
|
||||
assert len(results) == 1
|
||||
assert results[0].s == 'http://example.org/Alice'
|
||||
assert results[0].p == 'http://example.org/knows'
|
||||
assert results[0].o == 'http://example.org/Bob'
|
||||
|
||||
def test_get_p_returns_quads_for_predicate(self, entity_kg):
|
||||
"""Test get_p queries by predicate"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(s='http://example.org/Alice', o='http://example.org/Bob',
|
||||
d='', otype='u', dtype='', lang='', p='http://example.org/knows')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_p('test_collection', 'http://example.org/knows')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_o_returns_quads_for_object(self, entity_kg):
|
||||
"""Test get_o queries by object"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(s='http://example.org/Alice', p='http://example.org/knows',
|
||||
d='', otype='u', dtype='', lang='', o='http://example.org/Bob')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_o('test_collection', 'http://example.org/Bob')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_sp_returns_quads_for_subject_predicate(self, entity_kg):
|
||||
"""Test get_sp queries by subject and predicate"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(o='http://example.org/Bob', d='', otype='u', dtype='', lang='')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_sp('test_collection', 'http://example.org/Alice',
|
||||
'http://example.org/knows')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_po_returns_quads_for_predicate_object(self, entity_kg):
|
||||
"""Test get_po queries by predicate and object"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(s='http://example.org/Alice', d='', otype='u', dtype='', lang='',
|
||||
o='http://example.org/Bob')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_po('test_collection', 'http://example.org/knows',
|
||||
'http://example.org/Bob')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_os_returns_quads_for_object_subject(self, entity_kg):
|
||||
"""Test get_os queries by object and subject"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(p='http://example.org/knows', d='', otype='u', dtype='', lang='',
|
||||
s='http://example.org/Alice', o='http://example.org/Bob')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_os('test_collection', 'http://example.org/Bob',
|
||||
'http://example.org/Alice')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_spo_returns_quads_for_subject_predicate_object(self, entity_kg):
|
||||
"""Test get_spo queries by subject, predicate, and object"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(d='', otype='u', dtype='', lang='',
|
||||
o='http://example.org/Bob')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_spo('test_collection', 'http://example.org/Alice',
|
||||
'http://example.org/knows', 'http://example.org/Bob')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_g_returns_quads_for_graph(self, entity_kg):
|
||||
"""Test get_g queries by graph"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(s='http://example.org/Alice', p='http://example.org/knows',
|
||||
o='http://example.org/Bob', otype='u', dtype='', lang='')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_g('test_collection', 'http://example.org/graph1')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_get_all_returns_all_quads_in_collection(self, entity_kg):
|
||||
"""Test get_all returns all quads"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows',
|
||||
o='http://example.org/Bob', otype='u', dtype='', lang='')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_all('test_collection')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_graph_wildcard_returns_all_graphs(self, entity_kg):
|
||||
"""Test that g='*' returns quads from all graphs"""
|
||||
from trustgraph.direct.cassandra_kg import GRAPH_WILDCARD
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(p='http://example.org/knows', d='http://example.org/graph1',
|
||||
otype='u', dtype='', lang='', s='http://example.org/Alice',
|
||||
o='http://example.org/Bob'),
|
||||
MagicMock(p='http://example.org/knows', d='http://example.org/graph2',
|
||||
otype='u', dtype='', lang='', s='http://example.org/Alice',
|
||||
o='http://example.org/Charlie')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_s('test_collection', 'http://example.org/Alice', g=GRAPH_WILDCARD)
|
||||
|
||||
# Should return quads from both graphs
|
||||
assert len(results) == 2
|
||||
|
||||
def test_specific_graph_filters_results(self, entity_kg):
|
||||
"""Test that specifying a graph filters results"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(p='http://example.org/knows', d='http://example.org/graph1',
|
||||
otype='u', dtype='', lang='', s='http://example.org/Alice',
|
||||
o='http://example.org/Bob'),
|
||||
MagicMock(p='http://example.org/knows', d='http://example.org/graph2',
|
||||
otype='u', dtype='', lang='', s='http://example.org/Alice',
|
||||
o='http://example.org/Charlie')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_s('test_collection', 'http://example.org/Alice',
|
||||
g='http://example.org/graph1')
|
||||
|
||||
# Should only return quads from graph1
|
||||
assert len(results) == 1
|
||||
assert results[0].g == 'http://example.org/graph1'
|
||||
|
||||
def test_collection_exists_returns_true_when_exists(self, entity_kg):
|
||||
"""Test collection_exists returns True for existing collection"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [MagicMock(collection='test_collection')]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
exists = kg.collection_exists('test_collection')
|
||||
|
||||
assert exists is True
|
||||
|
||||
def test_collection_exists_returns_false_when_not_exists(self, entity_kg):
|
||||
"""Test collection_exists returns False for non-existing collection"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.execute.return_value = []
|
||||
|
||||
exists = kg.collection_exists('nonexistent_collection')
|
||||
|
||||
assert exists is False
|
||||
|
||||
def test_create_collection_inserts_metadata(self, entity_kg):
|
||||
"""Test create_collection inserts metadata row"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.reset_mock()
|
||||
kg.create_collection('test_collection')
|
||||
|
||||
# Verify INSERT was executed for collection_metadata
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_delete_collection_removes_all_data(self, entity_kg):
|
||||
"""Test delete_collection removes entity partitions and collection rows"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
# Mock reading quads from collection
|
||||
mock_quads = [
|
||||
MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows',
|
||||
o='http://example.org/Bob', otype='u')
|
||||
]
|
||||
mock_session.execute.return_value = mock_quads
|
||||
|
||||
mock_session.reset_mock()
|
||||
kg.delete_collection('test_collection')
|
||||
|
||||
# Verify delete operations were executed
|
||||
assert mock_session.execute.called
|
||||
|
||||
def test_close_shuts_down_connections(self, entity_kg):
|
||||
"""Test close shuts down session and cluster"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
kg.close()
|
||||
|
||||
mock_session.shutdown.assert_called_once()
|
||||
kg.cluster.shutdown.assert_called_once()
|
||||
|
||||
|
||||
class TestQuadResult:
|
||||
"""Test cases for QuadResult class"""
|
||||
|
||||
def test_quad_result_stores_all_fields(self):
|
||||
"""Test QuadResult stores all quad fields"""
|
||||
from trustgraph.direct.cassandra_kg import QuadResult
|
||||
|
||||
result = QuadResult(
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/knows',
|
||||
o='http://example.org/Bob',
|
||||
g='http://example.org/graph1',
|
||||
otype='u',
|
||||
dtype='',
|
||||
lang=''
|
||||
)
|
||||
|
||||
assert result.s == 'http://example.org/Alice'
|
||||
assert result.p == 'http://example.org/knows'
|
||||
assert result.o == 'http://example.org/Bob'
|
||||
assert result.g == 'http://example.org/graph1'
|
||||
assert result.otype == 'u'
|
||||
assert result.dtype == ''
|
||||
assert result.lang == ''
|
||||
|
||||
def test_quad_result_defaults(self):
|
||||
"""Test QuadResult default values"""
|
||||
from trustgraph.direct.cassandra_kg import QuadResult
|
||||
|
||||
result = QuadResult(
|
||||
s='http://example.org/s',
|
||||
p='http://example.org/p',
|
||||
o='literal value',
|
||||
g=''
|
||||
)
|
||||
|
||||
assert result.otype == 'u' # Default otype
|
||||
assert result.dtype == ''
|
||||
assert result.lang == ''
|
||||
|
||||
def test_quad_result_with_literal_metadata(self):
|
||||
"""Test QuadResult with literal metadata"""
|
||||
from trustgraph.direct.cassandra_kg import QuadResult
|
||||
|
||||
result = QuadResult(
|
||||
s='http://example.org/Alice',
|
||||
p='http://www.w3.org/2000/01/rdf-schema#label',
|
||||
o='Alice Smith',
|
||||
g='',
|
||||
otype='l',
|
||||
dtype='xsd:string',
|
||||
lang='en'
|
||||
)
|
||||
|
||||
assert result.otype == 'l'
|
||||
assert result.dtype == 'xsd:string'
|
||||
assert result.lang == 'en'
|
||||
|
||||
|
||||
class TestWriteHelperFunctions:
|
||||
"""Test cases for helper functions in write.py"""
|
||||
|
||||
def test_get_term_otype_for_iri(self):
|
||||
"""Test get_term_otype returns 'u' for IRI terms"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
from trustgraph.schema import Term, IRI
|
||||
|
||||
term = Term(type=IRI, iri='http://example.org/Alice')
|
||||
assert get_term_otype(term) == 'u'
|
||||
|
||||
def test_get_term_otype_for_literal(self):
|
||||
"""Test get_term_otype returns 'l' for LITERAL terms"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
from trustgraph.schema import Term, LITERAL
|
||||
|
||||
term = Term(type=LITERAL, value='Alice Smith')
|
||||
assert get_term_otype(term) == 'l'
|
||||
|
||||
def test_get_term_otype_for_blank(self):
|
||||
"""Test get_term_otype returns 'u' for BLANK terms"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
from trustgraph.schema import Term, BLANK
|
||||
|
||||
term = Term(type=BLANK, id='_:b1')
|
||||
assert get_term_otype(term) == 'u'
|
||||
|
||||
def test_get_term_otype_for_triple(self):
|
||||
"""Test get_term_otype returns 't' for TRIPLE terms"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
from trustgraph.schema import Term, TRIPLE
|
||||
|
||||
term = Term(type=TRIPLE)
|
||||
assert get_term_otype(term) == 't'
|
||||
|
||||
def test_get_term_otype_for_none(self):
|
||||
"""Test get_term_otype returns 'u' for None"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
|
||||
assert get_term_otype(None) == 'u'
|
||||
|
||||
def test_get_term_dtype_for_literal(self):
|
||||
"""Test get_term_dtype extracts datatype from LITERAL"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_dtype
|
||||
from trustgraph.schema import Term, LITERAL
|
||||
|
||||
term = Term(type=LITERAL, value='42', datatype='xsd:integer')
|
||||
assert get_term_dtype(term) == 'xsd:integer'
|
||||
|
||||
def test_get_term_dtype_for_non_literal(self):
|
||||
"""Test get_term_dtype returns empty string for non-LITERAL"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_dtype
|
||||
from trustgraph.schema import Term, IRI
|
||||
|
||||
term = Term(type=IRI, iri='http://example.org/Alice')
|
||||
assert get_term_dtype(term) == ''
|
||||
|
||||
def test_get_term_dtype_for_none(self):
|
||||
"""Test get_term_dtype returns empty string for None"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_dtype
|
||||
|
||||
assert get_term_dtype(None) == ''
|
||||
|
||||
def test_get_term_lang_for_literal(self):
|
||||
"""Test get_term_lang extracts language from LITERAL"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_lang
|
||||
from trustgraph.schema import Term, LITERAL
|
||||
|
||||
term = Term(type=LITERAL, value='Alice Smith', language='en')
|
||||
assert get_term_lang(term) == 'en'
|
||||
|
||||
def test_get_term_lang_for_non_literal(self):
|
||||
"""Test get_term_lang returns empty string for non-LITERAL"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_lang
|
||||
from trustgraph.schema import Term, IRI
|
||||
|
||||
term = Term(type=IRI, iri='http://example.org/Alice')
|
||||
assert get_term_lang(term) == ''
|
||||
|
||||
|
||||
class TestServiceHelperFunctions:
|
||||
"""Test cases for helper functions in service.py"""
|
||||
|
||||
def test_create_term_with_uri_otype(self):
|
||||
"""Test create_term creates IRI Term for otype='u'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
|
||||
term = create_term('http://example.org/Alice', otype='u')
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/Alice'
|
||||
|
||||
def test_create_term_with_literal_otype(self):
|
||||
"""Test create_term creates LITERAL Term for otype='l'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import LITERAL
|
||||
|
||||
term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en')
|
||||
|
||||
assert term.type == LITERAL
|
||||
assert term.value == 'Alice Smith'
|
||||
assert term.datatype == 'xsd:string'
|
||||
assert term.language == 'en'
|
||||
|
||||
def test_create_term_with_triple_otype(self):
|
||||
"""Test create_term creates IRI Term for otype='t'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
|
||||
term = create_term('http://example.org/statement1', otype='t')
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/statement1'
|
||||
|
||||
def test_create_term_heuristic_fallback_uri(self):
|
||||
"""Test create_term uses URL heuristic when otype not provided"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
|
||||
term = create_term('http://example.org/Alice')
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/Alice'
|
||||
|
||||
def test_create_term_heuristic_fallback_literal(self):
|
||||
"""Test create_term uses literal heuristic when otype not provided"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import LITERAL
|
||||
|
||||
term = create_term('Alice Smith')
|
||||
|
||||
assert term.type == LITERAL
|
||||
assert term.value == 'Alice Smith'
|
||||
380
tests/unit/test_embeddings/test_row_embeddings_processor.py
Normal file
380
tests/unit/test_embeddings/test_row_embeddings_processor.py
Normal file
|
|
@ -0,0 +1,380 @@
|
|||
"""
|
||||
Unit tests for trustgraph.embeddings.row_embeddings.embeddings
|
||||
Tests the Stage 1 processor that computes embeddings for row index fields.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
|
||||
class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test row embeddings processor functionality"""
|
||||
|
||||
async def test_processor_initialization(self):
|
||||
"""Test basic processor initialization"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-row-embeddings'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
assert hasattr(processor, 'schemas')
|
||||
assert processor.schemas == {}
|
||||
assert processor.batch_size == 10 # default
|
||||
|
||||
async def test_processor_initialization_with_custom_batch_size(self):
|
||||
"""Test processor initialization with custom batch size"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-row-embeddings',
|
||||
'batch_size': 25
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
assert processor.batch_size == 25
|
||||
|
||||
async def test_get_index_names_single_index(self):
|
||||
"""Test getting index names with single indexed field"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import RowSchema, Field
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
schema = RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
Field(name='email', type='text', indexed=False),
|
||||
]
|
||||
)
|
||||
|
||||
index_names = processor.get_index_names(schema)
|
||||
|
||||
# Should include primary key and indexed field
|
||||
assert 'id' in index_names
|
||||
assert 'name' in index_names
|
||||
assert 'email' not in index_names
|
||||
|
||||
async def test_get_index_names_no_indexes(self):
|
||||
"""Test getting index names when no fields are indexed"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import RowSchema, Field
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
schema = RowSchema(
|
||||
name='logs',
|
||||
description='Log records',
|
||||
fields=[
|
||||
Field(name='timestamp', type='text'),
|
||||
Field(name='message', type='text'),
|
||||
]
|
||||
)
|
||||
|
||||
index_names = processor.get_index_names(schema)
|
||||
|
||||
assert index_names == []
|
||||
|
||||
async def test_build_index_value_single_field(self):
|
||||
"""Test building index value for single field"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
value_map = {
|
||||
'id': 'CUST001',
|
||||
'name': 'John Doe',
|
||||
'email': 'john@example.com'
|
||||
}
|
||||
|
||||
result = processor.build_index_value(value_map, 'name')
|
||||
|
||||
assert result == ['John Doe']
|
||||
|
||||
async def test_build_index_value_composite_index(self):
|
||||
"""Test building index value for composite index"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
value_map = {
|
||||
'first_name': 'John',
|
||||
'last_name': 'Doe',
|
||||
'city': 'New York'
|
||||
}
|
||||
|
||||
result = processor.build_index_value(value_map, 'first_name, last_name')
|
||||
|
||||
assert result == ['John', 'Doe']
|
||||
|
||||
async def test_build_index_value_missing_field(self):
|
||||
"""Test building index value when field is missing"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
value_map = {
|
||||
'name': 'John Doe'
|
||||
}
|
||||
|
||||
result = processor.build_index_value(value_map, 'missing_field')
|
||||
|
||||
assert result == ['']
|
||||
|
||||
async def test_build_text_for_embedding_single_value(self):
|
||||
"""Test building text representation for single value"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
result = processor.build_text_for_embedding(['John Doe'])
|
||||
|
||||
assert result == 'John Doe'
|
||||
|
||||
async def test_build_text_for_embedding_multiple_values(self):
|
||||
"""Test building text representation for multiple values"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
result = processor.build_text_for_embedding(['John', 'Doe', 'NYC'])
|
||||
|
||||
assert result == 'John Doe NYC'
|
||||
|
||||
async def test_on_schema_config_loads_schemas(self):
|
||||
"""Test that schema configuration is loaded correctly"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
import json
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor',
|
||||
'config_type': 'schema'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
schema_def = {
|
||||
'name': 'customers',
|
||||
'description': 'Customer records',
|
||||
'fields': [
|
||||
{'name': 'id', 'type': 'text', 'primary_key': True},
|
||||
{'name': 'name', 'type': 'text', 'indexed': True},
|
||||
{'name': 'email', 'type': 'text'}
|
||||
]
|
||||
}
|
||||
|
||||
config_data = {
|
||||
'schema': {
|
||||
'customers': json.dumps(schema_def)
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
|
||||
assert 'customers' in processor.schemas
|
||||
assert processor.schemas['customers'].name == 'customers'
|
||||
assert len(processor.schemas['customers'].fields) == 3
|
||||
|
||||
async def test_on_schema_config_handles_missing_type(self):
|
||||
"""Test that missing schema type is handled gracefully"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor',
|
||||
'config_type': 'schema'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
config_data = {
|
||||
'other_type': {}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
|
||||
assert processor.schemas == {}
|
||||
|
||||
async def test_on_message_drops_unknown_collection(self):
|
||||
"""Test that messages for unknown collections are dropped"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import ExtractedObject
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
# No collections registered
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'unknown_user'
|
||||
metadata.collection = 'unknown_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
schema_name='customers',
|
||||
values=[{'id': '123', 'name': 'Test'}]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = obj
|
||||
|
||||
mock_flow = MagicMock()
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
# Flow should not be called for output
|
||||
mock_flow.assert_not_called()
|
||||
|
||||
async def test_on_message_drops_unknown_schema(self):
|
||||
"""Test that messages for unknown schemas are dropped"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import ExtractedObject
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
# No schemas registered
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
schema_name='unknown_schema',
|
||||
values=[{'id': '123', 'name': 'Test'}]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = obj
|
||||
|
||||
mock_flow = MagicMock()
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
# Flow should not be called for output
|
||||
mock_flow.assert_not_called()
|
||||
|
||||
async def test_on_message_processes_embeddings(self):
|
||||
"""Test processing a message and computing embeddings"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import ExtractedObject, RowSchema, Field
|
||||
import json
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor',
|
||||
'config_type': 'schema'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
|
||||
# Set up schema
|
||||
processor.schemas['customers'] = RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
schema_name='customers',
|
||||
values=[
|
||||
{'id': 'CUST001', 'name': 'John Doe'},
|
||||
{'id': 'CUST002', 'name': 'Jane Smith'}
|
||||
]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = obj
|
||||
|
||||
# Mock the flow
|
||||
mock_embeddings_request = AsyncMock()
|
||||
mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow_factory(name):
|
||||
if name == 'embeddings-request':
|
||||
return mock_embeddings_request
|
||||
elif name == 'output':
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
mock_flow = MagicMock(side_effect=flow_factory)
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
# Should have called embed for each unique text
|
||||
# 4 values: CUST001, John Doe, CUST002, Jane Smith
|
||||
assert mock_embeddings_request.embed.call_count == 4
|
||||
|
||||
# Should have sent output
|
||||
mock_output.send.assert_called()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -7,7 +7,7 @@ collecting labels and definitions for entity embedding and retrieval.
|
|||
|
||||
import pytest
|
||||
from trustgraph.extract.kg.ontology.extract import Processor
|
||||
from trustgraph.schema.core.primitives import Triple, Value
|
||||
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
|
||||
from trustgraph.schema.knowledge.graph import EntityContext
|
||||
|
||||
|
||||
|
|
@ -25,9 +25,9 @@ class TestEntityContextBuilding:
|
|||
"""Test that entity context is built from rdfs:label."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/cornish-pasty", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Cornish Pasty", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/cornish-pasty"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Cornish Pasty")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -35,16 +35,16 @@ class TestEntityContextBuilding:
|
|||
|
||||
assert len(contexts) == 1, "Should create one entity context"
|
||||
assert isinstance(contexts[0], EntityContext)
|
||||
assert contexts[0].entity.value == "https://example.com/entity/cornish-pasty"
|
||||
assert contexts[0].entity.iri == "https://example.com/entity/cornish-pasty"
|
||||
assert "Label: Cornish Pasty" in contexts[0].context
|
||||
|
||||
def test_builds_context_from_definition(self, processor):
|
||||
"""Test that entity context includes definitions."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/pasty", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="A baked pastry filled with savory ingredients", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/pasty"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="A baked pastry filled with savory ingredients")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -57,14 +57,14 @@ class TestEntityContextBuilding:
|
|||
"""Test that label and definition are combined in context."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Pasty Recipe", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Pasty Recipe")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="Traditional Cornish pastry recipe", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="Traditional Cornish pastry recipe")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -80,14 +80,14 @@ class TestEntityContextBuilding:
|
|||
"""Test that only the first label is used in context."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="First Label", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="First Label")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Second Label", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Second Label")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -101,14 +101,14 @@ class TestEntityContextBuilding:
|
|||
"""Test that all definitions are included in context."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="First definition", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="First definition")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="Second definition", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="Second definition")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -123,9 +123,9 @@ class TestEntityContextBuilding:
|
|||
"""Test that schema.org description is treated as definition."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="https://schema.org/description", is_uri=True),
|
||||
o=Value(value="A delicious food item", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="https://schema.org/description"),
|
||||
o=Term(type=LITERAL, value="A delicious food item")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -138,26 +138,26 @@ class TestEntityContextBuilding:
|
|||
"""Test that contexts are created for multiple entities."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Entity One", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Entity One")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity2", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Entity Two", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity2"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Entity Two")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity3", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Entity Three", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity3"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Entity Three")
|
||||
)
|
||||
]
|
||||
|
||||
contexts = processor.build_entity_contexts(triples)
|
||||
|
||||
assert len(contexts) == 3, "Should create context for each entity"
|
||||
entity_uris = [ctx.entity.value for ctx in contexts]
|
||||
entity_uris = [ctx.entity.iri for ctx in contexts]
|
||||
assert "https://example.com/entity/entity1" in entity_uris
|
||||
assert "https://example.com/entity/entity2" in entity_uris
|
||||
assert "https://example.com/entity/entity3" in entity_uris
|
||||
|
|
@ -166,9 +166,9 @@ class TestEntityContextBuilding:
|
|||
"""Test that URI objects are ignored (only literal labels/definitions)."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="https://example.com/some/uri", is_uri=True) # URI, not literal
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=IRI, iri="https://example.com/some/uri") # URI, not literal
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -181,14 +181,14 @@ class TestEntityContextBuilding:
|
|||
"""Test that other predicates are ignored."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://example.com/Food", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://example.com/Food")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://example.com/produces", is_uri=True),
|
||||
o=Value(value="https://example.com/entity/food2", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://example.com/produces"),
|
||||
o=Term(type=IRI, iri="https://example.com/entity/food2")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -205,29 +205,29 @@ class TestEntityContextBuilding:
|
|||
|
||||
assert len(contexts) == 0, "Empty triple list should return empty contexts"
|
||||
|
||||
def test_entity_context_has_value_object(self, processor):
|
||||
"""Test that EntityContext.entity is a Value object."""
|
||||
def test_entity_context_has_term_object(self, processor):
|
||||
"""Test that EntityContext.entity is a Term object."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/test", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Test Entity", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/test"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Test Entity")
|
||||
)
|
||||
]
|
||||
|
||||
contexts = processor.build_entity_contexts(triples)
|
||||
|
||||
assert len(contexts) == 1
|
||||
assert isinstance(contexts[0].entity, Value), "Entity should be Value object"
|
||||
assert contexts[0].entity.is_uri, "Entity should be marked as URI"
|
||||
assert isinstance(contexts[0].entity, Term), "Entity should be Term object"
|
||||
assert contexts[0].entity.type == IRI, "Entity should be IRI type"
|
||||
|
||||
def test_entity_context_text_is_string(self, processor):
|
||||
"""Test that EntityContext.context is a string."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/test", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Test Entity", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/test"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Test Entity")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -241,22 +241,22 @@ class TestEntityContextBuilding:
|
|||
triples = [
|
||||
# Entity with label - should create context
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Entity One", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Entity One")
|
||||
),
|
||||
# Entity with only rdf:type - should NOT create context
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity2", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://example.com/Food", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity2"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://example.com/Food")
|
||||
)
|
||||
]
|
||||
|
||||
contexts = processor.build_entity_contexts(triples)
|
||||
|
||||
assert len(contexts) == 1, "Should only create context for entity with label/definition"
|
||||
assert contexts[0].entity.value == "https://example.com/entity/entity1"
|
||||
assert contexts[0].entity.iri == "https://example.com/entity/entity1"
|
||||
|
||||
|
||||
class TestEntityContextEdgeCases:
|
||||
|
|
@ -266,9 +266,9 @@ class TestEntityContextEdgeCases:
|
|||
"""Test handling of unicode characters in labels."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/café", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Café Spécial", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/café"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Café Spécial")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -282,9 +282,9 @@ class TestEntityContextEdgeCases:
|
|||
long_def = "This is a very long definition " * 50
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/test", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value=long_def, is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/test"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value=long_def)
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -297,9 +297,9 @@ class TestEntityContextEdgeCases:
|
|||
"""Test handling of special characters in context text."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/test", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Test & Entity <with> \"quotes\"", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/test"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Test & Entity <with> \"quotes\"")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -313,27 +313,27 @@ class TestEntityContextEdgeCases:
|
|||
triples = [
|
||||
# Label - relevant
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Cornish Pasty Recipe", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Cornish Pasty Recipe")
|
||||
),
|
||||
# Type - irrelevant
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://example.com/Recipe", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://example.com/Recipe")
|
||||
),
|
||||
# Property - irrelevant
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://example.com/produces", is_uri=True),
|
||||
o=Value(value="https://example.com/entity/pasty", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://example.com/produces"),
|
||||
o=Term(type=IRI, iri="https://example.com/entity/pasty")
|
||||
),
|
||||
# Definition - relevant
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="Traditional British pastry recipe", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="Traditional British pastry recipe")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ the knowledge graph.
|
|||
import pytest
|
||||
from trustgraph.extract.kg.ontology.extract import Processor
|
||||
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
|
||||
from trustgraph.schema.core.primitives import Triple, Value
|
||||
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -92,12 +92,12 @@ class TestOntologyTripleGeneration:
|
|||
# Find type triples for Recipe class
|
||||
recipe_type_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
]
|
||||
|
||||
assert len(recipe_type_triples) == 1, "Should generate exactly one type triple per class"
|
||||
assert recipe_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#Class", \
|
||||
assert recipe_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#Class", \
|
||||
"Class type should be owl:Class"
|
||||
|
||||
def test_generates_class_labels(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -107,14 +107,14 @@ class TestOntologyTripleGeneration:
|
|||
# Find label triples for Recipe class
|
||||
recipe_label_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
]
|
||||
|
||||
assert len(recipe_label_triples) == 1, "Should generate label triple for class"
|
||||
assert recipe_label_triples[0].o.value == "Recipe", \
|
||||
"Label should match class label from ontology"
|
||||
assert not recipe_label_triples[0].o.is_uri, \
|
||||
assert recipe_label_triples[0].o.type == LITERAL, \
|
||||
"Label should be a literal, not URI"
|
||||
|
||||
def test_generates_class_comments(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -124,8 +124,8 @@ class TestOntologyTripleGeneration:
|
|||
# Find comment triples for Recipe class
|
||||
recipe_comment_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#comment"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#comment"
|
||||
]
|
||||
|
||||
assert len(recipe_comment_triples) == 1, "Should generate comment triple for class"
|
||||
|
|
@ -139,13 +139,13 @@ class TestOntologyTripleGeneration:
|
|||
# Find type triples for ingredients property
|
||||
ingredients_type_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
]
|
||||
|
||||
assert len(ingredients_type_triples) == 1, \
|
||||
"Should generate exactly one type triple per object property"
|
||||
assert ingredients_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#ObjectProperty", \
|
||||
assert ingredients_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#ObjectProperty", \
|
||||
"Object property type should be owl:ObjectProperty"
|
||||
|
||||
def test_generates_object_property_labels(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -155,8 +155,8 @@ class TestOntologyTripleGeneration:
|
|||
# Find label triples for ingredients property
|
||||
ingredients_label_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
]
|
||||
|
||||
assert len(ingredients_label_triples) == 1, \
|
||||
|
|
@ -171,15 +171,15 @@ class TestOntologyTripleGeneration:
|
|||
# Find domain triples for ingredients property
|
||||
ingredients_domain_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#domain"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#domain"
|
||||
]
|
||||
|
||||
assert len(ingredients_domain_triples) == 1, \
|
||||
"Should generate domain triple for object property"
|
||||
assert ingredients_domain_triples[0].o.value == "http://purl.org/ontology/fo/Recipe", \
|
||||
assert ingredients_domain_triples[0].o.iri == "http://purl.org/ontology/fo/Recipe", \
|
||||
"Domain should be Recipe class URI"
|
||||
assert ingredients_domain_triples[0].o.is_uri, \
|
||||
assert ingredients_domain_triples[0].o.type == IRI, \
|
||||
"Domain should be a URI reference"
|
||||
|
||||
def test_generates_object_property_range(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -189,13 +189,13 @@ class TestOntologyTripleGeneration:
|
|||
# Find range triples for produces property
|
||||
produces_range_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/produces"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/produces"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range"
|
||||
]
|
||||
|
||||
assert len(produces_range_triples) == 1, \
|
||||
"Should generate range triple for object property"
|
||||
assert produces_range_triples[0].o.value == "http://purl.org/ontology/fo/Food", \
|
||||
assert produces_range_triples[0].o.iri == "http://purl.org/ontology/fo/Food", \
|
||||
"Range should be Food class URI"
|
||||
|
||||
def test_generates_datatype_property_type_triples(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -205,13 +205,13 @@ class TestOntologyTripleGeneration:
|
|||
# Find type triples for serves property
|
||||
serves_type_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/serves"
|
||||
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/serves"
|
||||
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
]
|
||||
|
||||
assert len(serves_type_triples) == 1, \
|
||||
"Should generate exactly one type triple per datatype property"
|
||||
assert serves_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#DatatypeProperty", \
|
||||
assert serves_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#DatatypeProperty", \
|
||||
"Datatype property type should be owl:DatatypeProperty"
|
||||
|
||||
def test_generates_datatype_property_range(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -221,13 +221,13 @@ class TestOntologyTripleGeneration:
|
|||
# Find range triples for serves property
|
||||
serves_range_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/serves"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/serves"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range"
|
||||
]
|
||||
|
||||
assert len(serves_range_triples) == 1, \
|
||||
"Should generate range triple for datatype property"
|
||||
assert serves_range_triples[0].o.value == "http://www.w3.org/2001/XMLSchema#string", \
|
||||
assert serves_range_triples[0].o.iri == "http://www.w3.org/2001/XMLSchema#string", \
|
||||
"Range should be XSD type URI (xsd:string expanded)"
|
||||
|
||||
def test_generates_triples_for_all_classes(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -236,9 +236,9 @@ class TestOntologyTripleGeneration:
|
|||
|
||||
# Count unique class subjects
|
||||
class_subjects = set(
|
||||
t.s.value for t in triples
|
||||
if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
and t.o.value == "http://www.w3.org/2002/07/owl#Class"
|
||||
t.s.iri for t in triples
|
||||
if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
and t.o.iri == "http://www.w3.org/2002/07/owl#Class"
|
||||
)
|
||||
|
||||
assert len(class_subjects) == 3, \
|
||||
|
|
@ -250,9 +250,9 @@ class TestOntologyTripleGeneration:
|
|||
|
||||
# Count unique property subjects (object + datatype properties)
|
||||
property_subjects = set(
|
||||
t.s.value for t in triples
|
||||
if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
and ("ObjectProperty" in t.o.value or "DatatypeProperty" in t.o.value)
|
||||
t.s.iri for t in triples
|
||||
if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
and ("ObjectProperty" in t.o.iri or "DatatypeProperty" in t.o.iri)
|
||||
)
|
||||
|
||||
assert len(property_subjects) == 3, \
|
||||
|
|
@ -276,7 +276,7 @@ class TestOntologyTripleGeneration:
|
|||
# Should still generate proper RDF triples despite dict field names
|
||||
label_triples = [
|
||||
t for t in triples
|
||||
if t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
if t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
]
|
||||
assert len(label_triples) > 0, \
|
||||
"Should generate rdfs:label triples from dict 'labels' field"
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ and extracts/validates triples from LLM responses.
|
|||
import pytest
|
||||
from trustgraph.extract.kg.ontology.extract import Processor
|
||||
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
|
||||
from trustgraph.schema.core.primitives import Triple, Value
|
||||
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -248,9 +248,9 @@ class TestTripleParsing:
|
|||
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
|
||||
|
||||
assert len(validated) == 1, "Should parse one valid triple"
|
||||
assert validated[0].s.value == "https://trustgraph.ai/food/cornish-pasty"
|
||||
assert validated[0].p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe"
|
||||
assert validated[0].s.iri == "https://trustgraph.ai/food/cornish-pasty"
|
||||
assert validated[0].p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
|
||||
def test_parse_multiple_triples(self, extractor, sample_ontology_subset):
|
||||
"""Test parsing multiple triples."""
|
||||
|
|
@ -307,11 +307,11 @@ class TestTripleParsing:
|
|||
|
||||
assert len(validated) == 1
|
||||
# Subject should be expanded to entity URI
|
||||
assert validated[0].s.value.startswith("https://trustgraph.ai/food/")
|
||||
assert validated[0].s.iri.startswith("https://trustgraph.ai/food/")
|
||||
# Predicate should be expanded to ontology URI
|
||||
assert validated[0].p.value == "http://purl.org/ontology/fo/produces"
|
||||
assert validated[0].p.iri == "http://purl.org/ontology/fo/produces"
|
||||
# Object should be expanded to class URI
|
||||
assert validated[0].o.value == "http://purl.org/ontology/fo/Food"
|
||||
assert validated[0].o.iri == "http://purl.org/ontology/fo/Food"
|
||||
|
||||
def test_creates_proper_triple_objects(self, extractor, sample_ontology_subset):
|
||||
"""Test that Triple objects are properly created."""
|
||||
|
|
@ -324,12 +324,12 @@ class TestTripleParsing:
|
|||
assert len(validated) == 1
|
||||
triple = validated[0]
|
||||
assert isinstance(triple, Triple), "Should create Triple objects"
|
||||
assert isinstance(triple.s, Value), "Subject should be Value object"
|
||||
assert isinstance(triple.p, Value), "Predicate should be Value object"
|
||||
assert isinstance(triple.o, Value), "Object should be Value object"
|
||||
assert triple.s.is_uri, "Subject should be marked as URI"
|
||||
assert triple.p.is_uri, "Predicate should be marked as URI"
|
||||
assert not triple.o.is_uri, "Object literal should not be marked as URI"
|
||||
assert isinstance(triple.s, Term), "Subject should be Term object"
|
||||
assert isinstance(triple.p, Term), "Predicate should be Term object"
|
||||
assert isinstance(triple.o, Term), "Object should be Term object"
|
||||
assert triple.s.type == IRI, "Subject should be IRI type"
|
||||
assert triple.p.type == IRI, "Predicate should be IRI type"
|
||||
assert triple.o.type == LITERAL, "Object literal should be LITERAL type"
|
||||
|
||||
|
||||
class TestURIExpansionInExtraction:
|
||||
|
|
@ -343,8 +343,8 @@ class TestURIExpansionInExtraction:
|
|||
|
||||
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
|
||||
|
||||
assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe"
|
||||
assert validated[0].o.is_uri, "Class reference should be URI"
|
||||
assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
assert validated[0].o.type == IRI, "Class reference should be URI"
|
||||
|
||||
def test_expands_property_names(self, extractor, sample_ontology_subset):
|
||||
"""Test that property names are expanded to full URIs."""
|
||||
|
|
@ -354,7 +354,7 @@ class TestURIExpansionInExtraction:
|
|||
|
||||
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
|
||||
|
||||
assert validated[0].p.value == "http://purl.org/ontology/fo/produces"
|
||||
assert validated[0].p.iri == "http://purl.org/ontology/fo/produces"
|
||||
|
||||
def test_expands_entity_instances(self, extractor, sample_ontology_subset):
|
||||
"""Test that entity instances get constructed URIs."""
|
||||
|
|
@ -364,8 +364,8 @@ class TestURIExpansionInExtraction:
|
|||
|
||||
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
|
||||
|
||||
assert validated[0].s.value.startswith("https://trustgraph.ai/food/")
|
||||
assert "my-special-recipe" in validated[0].s.value
|
||||
assert validated[0].s.iri.startswith("https://trustgraph.ai/food/")
|
||||
assert "my-special-recipe" in validated[0].s.iri
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.dispatch.serialize import to_value, to_subgraph, serialize_value
|
||||
from trustgraph.schema import Value, Triple
|
||||
from trustgraph.schema import Term, Triple, IRI, LITERAL
|
||||
|
||||
|
||||
class TestDispatchSerialize:
|
||||
|
|
@ -14,55 +14,55 @@ class TestDispatchSerialize:
|
|||
|
||||
def test_to_value_with_uri(self):
|
||||
"""Test to_value function with URI"""
|
||||
input_data = {"v": "http://example.com/resource", "e": True}
|
||||
|
||||
input_data = {"t": "i", "i": "http://example.com/resource"}
|
||||
|
||||
result = to_value(input_data)
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_to_value_with_literal(self):
|
||||
"""Test to_value function with literal value"""
|
||||
input_data = {"v": "literal string", "e": False}
|
||||
|
||||
input_data = {"t": "l", "v": "literal string"}
|
||||
|
||||
result = to_value(input_data)
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_to_subgraph_with_multiple_triples(self):
|
||||
"""Test to_subgraph function with multiple triples"""
|
||||
input_data = [
|
||||
{
|
||||
"s": {"v": "subject1", "e": True},
|
||||
"p": {"v": "predicate1", "e": True},
|
||||
"o": {"v": "object1", "e": False}
|
||||
"s": {"t": "i", "i": "subject1"},
|
||||
"p": {"t": "i", "i": "predicate1"},
|
||||
"o": {"t": "l", "v": "object1"}
|
||||
},
|
||||
{
|
||||
"s": {"v": "subject2", "e": False},
|
||||
"p": {"v": "predicate2", "e": True},
|
||||
"o": {"v": "object2", "e": True}
|
||||
"s": {"t": "l", "v": "subject2"},
|
||||
"p": {"t": "i", "i": "predicate2"},
|
||||
"o": {"t": "i", "i": "object2"}
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
result = to_subgraph(input_data)
|
||||
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(triple, Triple) for triple in result)
|
||||
|
||||
|
||||
# Check first triple
|
||||
assert result[0].s.value == "subject1"
|
||||
assert result[0].s.is_uri is True
|
||||
assert result[0].p.value == "predicate1"
|
||||
assert result[0].p.is_uri is True
|
||||
assert result[0].s.iri == "subject1"
|
||||
assert result[0].s.type == IRI
|
||||
assert result[0].p.iri == "predicate1"
|
||||
assert result[0].p.type == IRI
|
||||
assert result[0].o.value == "object1"
|
||||
assert result[0].o.is_uri is False
|
||||
|
||||
assert result[0].o.type == LITERAL
|
||||
|
||||
# Check second triple
|
||||
assert result[1].s.value == "subject2"
|
||||
assert result[1].s.is_uri is False
|
||||
assert result[1].s.type == LITERAL
|
||||
|
||||
def test_to_subgraph_with_empty_list(self):
|
||||
"""Test to_subgraph function with empty input"""
|
||||
|
|
@ -74,16 +74,16 @@ class TestDispatchSerialize:
|
|||
|
||||
def test_serialize_value_with_uri(self):
|
||||
"""Test serialize_value function with URI value"""
|
||||
value = Value(value="http://example.com/test", is_uri=True)
|
||||
|
||||
result = serialize_value(value)
|
||||
|
||||
assert result == {"v": "http://example.com/test", "e": True}
|
||||
term = Term(type=IRI, iri="http://example.com/test")
|
||||
|
||||
result = serialize_value(term)
|
||||
|
||||
assert result == {"t": "i", "i": "http://example.com/test"}
|
||||
|
||||
def test_serialize_value_with_literal(self):
|
||||
"""Test serialize_value function with literal value"""
|
||||
value = Value(value="test literal", is_uri=False)
|
||||
|
||||
result = serialize_value(value)
|
||||
|
||||
assert result == {"v": "test literal", "e": False}
|
||||
term = Term(type=LITERAL, value="test literal")
|
||||
|
||||
result = serialize_value(term)
|
||||
|
||||
assert result == {"t": "l", "v": "test literal"}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Unit tests for objects import dispatcher.
|
||||
Unit tests for rows import dispatcher.
|
||||
|
||||
Tests the business logic of objects import dispatcher
|
||||
Tests the business logic of rows import dispatcher
|
||||
while mocking the Publisher and websocket components.
|
||||
"""
|
||||
|
||||
|
|
@ -11,7 +11,7 @@ import asyncio
|
|||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
|
||||
from trustgraph.gateway.dispatch.rows_import import RowsImport
|
||||
from trustgraph.schema import Metadata, ExtractedObject
|
||||
|
||||
|
||||
|
|
@ -92,16 +92,16 @@ def minimal_objects_message():
|
|||
}
|
||||
|
||||
|
||||
class TestObjectsImportInitialization:
|
||||
"""Test ObjectsImport initialization."""
|
||||
class TestRowsImportInitialization:
|
||||
"""Test RowsImport initialization."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that ObjectsImport creates Publisher with correct parameters."""
|
||||
"""Test that RowsImport creates Publisher with correct parameters."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -116,28 +116,28 @@ class TestObjectsImportInitialization:
|
|||
)
|
||||
|
||||
# Verify instance variables are set correctly
|
||||
assert objects_import.ws == mock_websocket
|
||||
assert objects_import.running == mock_running
|
||||
assert objects_import.publisher == mock_publisher_instance
|
||||
assert rows_import.ws == mock_websocket
|
||||
assert rows_import.running == mock_running
|
||||
assert rows_import.publisher == mock_publisher_instance
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that ObjectsImport stores all required references."""
|
||||
objects_import = ObjectsImport(
|
||||
"""Test that RowsImport stores all required references."""
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="objects-queue"
|
||||
)
|
||||
|
||||
assert objects_import.ws is mock_websocket
|
||||
assert objects_import.running is mock_running
|
||||
assert rows_import.ws is mock_websocket
|
||||
assert rows_import.running is mock_running
|
||||
|
||||
|
||||
class TestObjectsImportLifecycle:
|
||||
"""Test ObjectsImport lifecycle methods."""
|
||||
class TestRowsImportLifecycle:
|
||||
"""Test RowsImport lifecycle methods."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that start() calls publisher.start()."""
|
||||
|
|
@ -145,18 +145,18 @@ class TestObjectsImportLifecycle:
|
|||
mock_publisher_instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.start()
|
||||
await rows_import.start()
|
||||
|
||||
mock_publisher_instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that destroy() properly stops publisher and closes websocket."""
|
||||
|
|
@ -164,21 +164,21 @@ class TestObjectsImportLifecycle:
|
|||
mock_publisher_instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.destroy()
|
||||
await rows_import.destroy()
|
||||
|
||||
# Verify sequence of operations
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_publisher_instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
|
||||
"""Test that destroy() handles None websocket gracefully."""
|
||||
|
|
@ -186,7 +186,7 @@ class TestObjectsImportLifecycle:
|
|||
mock_publisher_instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=None, # None websocket
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -194,16 +194,16 @@ class TestObjectsImportLifecycle:
|
|||
)
|
||||
|
||||
# Should not raise exception
|
||||
await objects_import.destroy()
|
||||
await rows_import.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_publisher_instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestObjectsImportMessageProcessing:
|
||||
"""Test ObjectsImport message processing."""
|
||||
class TestRowsImportMessageProcessing:
|
||||
"""Test RowsImport message processing."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
||||
"""Test that receive() processes complete message correctly."""
|
||||
|
|
@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing:
|
|||
assert sent_object.metadata.collection == "testcollection"
|
||||
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
|
||||
"""Test that receive() handles message with minimal required fields."""
|
||||
|
|
@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = minimal_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing:
|
|||
assert sent_object.source_span == "" # Default value
|
||||
assert len(sent_object.metadata.metadata) == 0 # Default empty list
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() uses appropriate default values for optional fields."""
|
||||
|
|
@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = message_data
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Get the sent object and verify defaults
|
||||
sent_object = mock_publisher_instance.send.call_args[0][1]
|
||||
|
|
@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing:
|
|||
assert sent_object.source_span == ""
|
||||
|
||||
|
||||
class TestObjectsImportRunMethod:
|
||||
"""Test ObjectsImport run method."""
|
||||
class TestRowsImportRunMethod:
|
||||
"""Test RowsImport run method."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that run() loops while running.get() returns True."""
|
||||
|
|
@ -331,14 +331,14 @@ class TestObjectsImportRunMethod:
|
|||
# Set up running state to return True twice, then False
|
||||
mock_running.get.side_effect = [True, True, False]
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.run()
|
||||
await rows_import.run()
|
||||
|
||||
# Verify sleep was called twice (for the two True iterations)
|
||||
assert mock_sleep.call_count == 2
|
||||
|
|
@ -348,10 +348,10 @@ class TestObjectsImportRunMethod:
|
|||
mock_websocket.close.assert_called_once()
|
||||
|
||||
# Verify websocket was set to None
|
||||
assert objects_import.ws is None
|
||||
assert rows_import.ws is None
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
|
||||
"""Test that run() handles None websocket gracefully."""
|
||||
|
|
@ -360,7 +360,7 @@ class TestObjectsImportRunMethod:
|
|||
|
||||
mock_running.get.return_value = False # Exit immediately
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=None, # None websocket
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -368,14 +368,14 @@ class TestObjectsImportRunMethod:
|
|||
)
|
||||
|
||||
# Should not raise exception
|
||||
await objects_import.run()
|
||||
await rows_import.run()
|
||||
|
||||
# Verify websocket remains None
|
||||
assert objects_import.ws is None
|
||||
assert rows_import.ws is None
|
||||
|
||||
|
||||
class TestObjectsImportBatchProcessing:
|
||||
"""Test ObjectsImport batch processing functionality."""
|
||||
class TestRowsImportBatchProcessing:
|
||||
"""Test RowsImport batch processing functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def batch_objects_message(self):
|
||||
|
|
@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing:
|
|||
"source_span": "Multiple people found in document"
|
||||
}
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
|
||||
"""Test that receive() processes batch message correctly."""
|
||||
|
|
@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = batch_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing:
|
|||
assert sent_object.confidence == 0.85
|
||||
assert sent_object.source_span == "Multiple people found in document"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() handles empty batch correctly."""
|
||||
|
|
@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_batch_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Should still send the message
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing:
|
|||
assert len(sent_object.values) == 0
|
||||
|
||||
|
||||
class TestObjectsImportErrorHandling:
|
||||
"""Test error handling in ObjectsImport."""
|
||||
class TestRowsImportErrorHandling:
|
||||
"""Test error handling in RowsImport."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
||||
"""Test that receive() propagates publisher send errors."""
|
||||
|
|
@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling:
|
|||
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling:
|
|||
mock_msg.json.return_value = sample_objects_message
|
||||
|
||||
with pytest.raises(Exception, match="Publisher error"):
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() handles malformed JSON appropriately."""
|
||||
mock_publisher_class.return_value = Mock()
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling:
|
|||
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
|
@ -6,11 +6,21 @@ import pytest
|
|||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
# Mock schema classes for testing
|
||||
class Value:
|
||||
def __init__(self, value, is_uri, type):
|
||||
self.value = value
|
||||
self.is_uri = is_uri
|
||||
# Term type constants
|
||||
IRI = "i"
|
||||
LITERAL = "l"
|
||||
BLANK = "b"
|
||||
TRIPLE = "t"
|
||||
|
||||
class Term:
|
||||
def __init__(self, type, iri=None, value=None, id=None, datatype=None, language=None, triple=None):
|
||||
self.type = type
|
||||
self.iri = iri
|
||||
self.value = value
|
||||
self.id = id
|
||||
self.datatype = datatype
|
||||
self.language = language
|
||||
self.triple = triple
|
||||
|
||||
class Triple:
|
||||
def __init__(self, s, p, o):
|
||||
|
|
@ -66,32 +76,30 @@ def sample_relationships():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_value_uri():
|
||||
"""Sample URI Value object"""
|
||||
return Value(
|
||||
value="http://example.com/person/john-smith",
|
||||
is_uri=True,
|
||||
type=""
|
||||
def sample_term_uri():
|
||||
"""Sample URI Term object"""
|
||||
return Term(
|
||||
type=IRI,
|
||||
iri="http://example.com/person/john-smith"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_value_literal():
|
||||
"""Sample literal Value object"""
|
||||
return Value(
|
||||
value="John Smith",
|
||||
is_uri=False,
|
||||
type="string"
|
||||
def sample_term_literal():
|
||||
"""Sample literal Term object"""
|
||||
return Term(
|
||||
type=LITERAL,
|
||||
value="John Smith"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_triple(sample_value_uri, sample_value_literal):
|
||||
def sample_triple(sample_term_uri, sample_term_literal):
|
||||
"""Sample Triple object"""
|
||||
return Triple(
|
||||
s=sample_value_uri,
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=sample_value_literal
|
||||
s=sample_term_uri,
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=sample_term_literal
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import json
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from trustgraph.template.prompt_manager import PromptManager
|
||||
|
|
@ -33,7 +33,7 @@ class TestAgentKgExtractor:
|
|||
|
||||
# Set up the methods we want to test
|
||||
extractor.to_uri = real_extractor.to_uri
|
||||
extractor.parse_json = real_extractor.parse_json
|
||||
extractor.parse_jsonl = real_extractor.parse_jsonl
|
||||
extractor.process_extraction_data = real_extractor.process_extraction_data
|
||||
extractor.emit_triples = real_extractor.emit_triples
|
||||
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
|
||||
|
|
@ -53,48 +53,49 @@ class TestAgentKgExtractor:
|
|||
id="doc123",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Value(value="doc123", is_uri=True),
|
||||
p=Value(value="http://example.org/type", is_uri=True),
|
||||
o=Value(value="document", is_uri=False)
|
||||
s=Term(type=IRI, iri="doc123"),
|
||||
p=Term(type=IRI, iri="http://example.org/type"),
|
||||
o=Term(type=LITERAL, value="document")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_extraction_data(self):
|
||||
"""Sample extraction data in expected format"""
|
||||
return {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "accuracy",
|
||||
"object": "95%",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
}
|
||||
"""Sample extraction data in JSONL format (list with type discriminators)"""
|
||||
return [
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "accuracy",
|
||||
"object": "95%",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
|
||||
def test_to_uri_conversion(self, agent_extractor):
|
||||
"""Test URI conversion for entities"""
|
||||
|
|
@ -113,148 +114,147 @@ class TestAgentKgExtractor:
|
|||
expected = f"{TRUSTGRAPH_ENTITIES}"
|
||||
assert uri == expected
|
||||
|
||||
def test_parse_json_with_code_blocks(self, agent_extractor):
|
||||
"""Test JSON parsing from code blocks"""
|
||||
# Test JSON in code blocks
|
||||
def test_parse_jsonl_with_code_blocks(self, agent_extractor):
|
||||
"""Test JSONL parsing from code blocks"""
|
||||
# Test JSONL in code blocks - note: JSON uses lowercase true/false
|
||||
response = '''```json
|
||||
{
|
||||
"definitions": [{"entity": "AI", "definition": "Artificial Intelligence"}],
|
||||
"relationships": []
|
||||
}
|
||||
```'''
|
||||
|
||||
result = agent_extractor.parse_json(response)
|
||||
|
||||
assert result["definitions"][0]["entity"] == "AI"
|
||||
assert result["definitions"][0]["definition"] == "Artificial Intelligence"
|
||||
assert result["relationships"] == []
|
||||
{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"}
|
||||
{"type": "relationship", "subject": "AI", "predicate": "is", "object": "technology", "object-entity": false}
|
||||
```'''
|
||||
|
||||
def test_parse_json_without_code_blocks(self, agent_extractor):
|
||||
"""Test JSON parsing without code blocks"""
|
||||
response = '''{"definitions": [{"entity": "ML", "definition": "Machine Learning"}], "relationships": []}'''
|
||||
|
||||
result = agent_extractor.parse_json(response)
|
||||
|
||||
assert result["definitions"][0]["entity"] == "ML"
|
||||
assert result["definitions"][0]["definition"] == "Machine Learning"
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
def test_parse_json_invalid_format(self, agent_extractor):
|
||||
"""Test JSON parsing with invalid format"""
|
||||
invalid_response = "This is not JSON at all"
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
agent_extractor.parse_json(invalid_response)
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "AI"
|
||||
assert result[0]["definition"] == "Artificial Intelligence"
|
||||
assert result[1]["type"] == "relationship"
|
||||
|
||||
def test_parse_json_malformed_code_blocks(self, agent_extractor):
|
||||
"""Test JSON parsing with malformed code blocks"""
|
||||
# Missing closing backticks
|
||||
response = '''```json
|
||||
{"definitions": [], "relationships": []}
|
||||
'''
|
||||
|
||||
# Should still parse the JSON content
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
agent_extractor.parse_json(response)
|
||||
def test_parse_jsonl_without_code_blocks(self, agent_extractor):
|
||||
"""Test JSONL parsing without code blocks"""
|
||||
response = '''{"type": "definition", "entity": "ML", "definition": "Machine Learning"}
|
||||
{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"}'''
|
||||
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "ML"
|
||||
assert result[1]["entity"] == "AI"
|
||||
|
||||
def test_parse_jsonl_invalid_lines_skipped(self, agent_extractor):
|
||||
"""Test JSONL parsing skips invalid lines gracefully"""
|
||||
response = '''{"type": "definition", "entity": "Valid", "definition": "Valid def"}
|
||||
This is not JSON at all
|
||||
{"type": "definition", "entity": "Also Valid", "definition": "Another def"}'''
|
||||
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
# Should get 2 valid objects, skipping the invalid line
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "Valid"
|
||||
assert result[1]["entity"] == "Also Valid"
|
||||
|
||||
def test_parse_jsonl_truncation_resilience(self, agent_extractor):
|
||||
"""Test JSONL parsing handles truncated responses"""
|
||||
# Simulates output cut off mid-line
|
||||
response = '''{"type": "definition", "entity": "Complete", "definition": "Full def"}
|
||||
{"type": "definition", "entity": "Trunca'''
|
||||
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
# Should get 1 valid object, the truncated line is skipped
|
||||
assert len(result) == 1
|
||||
assert result[0]["entity"] == "Complete"
|
||||
|
||||
def test_process_extraction_data_definitions(self, agent_extractor, sample_metadata):
|
||||
"""Test processing of definition data"""
|
||||
data = {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of AI that enables learning from data."
|
||||
}
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
data = [
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of AI that enables learning from data."
|
||||
}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Check entity label triple
|
||||
label_triple = next((t for t in triples if t.p.value == RDF_LABEL and t.o.value == "Machine Learning"), None)
|
||||
label_triple = next((t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "Machine Learning"), None)
|
||||
assert label_triple is not None
|
||||
assert label_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert label_triple.s.is_uri == True
|
||||
assert label_triple.o.is_uri == False
|
||||
|
||||
assert label_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert label_triple.s.type == IRI
|
||||
assert label_triple.o.type == LITERAL
|
||||
|
||||
# Check definition triple
|
||||
def_triple = next((t for t in triples if t.p.value == DEFINITION), None)
|
||||
def_triple = next((t for t in triples if t.p.iri == DEFINITION), None)
|
||||
assert def_triple is not None
|
||||
assert def_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert def_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert def_triple.o.value == "A subset of AI that enables learning from data."
|
||||
|
||||
|
||||
# Check subject-of triple
|
||||
subject_of_triple = next((t for t in triples if t.p.value == SUBJECT_OF), None)
|
||||
subject_of_triple = next((t for t in triples if t.p.iri == SUBJECT_OF), None)
|
||||
assert subject_of_triple is not None
|
||||
assert subject_of_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert subject_of_triple.o.value == "doc123"
|
||||
|
||||
assert subject_of_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert subject_of_triple.o.iri == "doc123"
|
||||
|
||||
# Check entity context
|
||||
assert len(entity_contexts) == 1
|
||||
assert entity_contexts[0].entity.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert entity_contexts[0].entity.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert entity_contexts[0].context == "A subset of AI that enables learning from data."
|
||||
|
||||
def test_process_extraction_data_relationships(self, agent_extractor, sample_metadata):
|
||||
"""Test processing of relationship data"""
|
||||
data = {
|
||||
"definitions": [],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
data = [
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Check that subject, predicate, and object labels are created
|
||||
subject_uri = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
predicate_uri = f"{TRUSTGRAPH_ENTITIES}is_subset_of"
|
||||
|
||||
|
||||
# Find label triples
|
||||
subject_label = next((t for t in triples if t.s.value == subject_uri and t.p.value == RDF_LABEL), None)
|
||||
subject_label = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == RDF_LABEL), None)
|
||||
assert subject_label is not None
|
||||
assert subject_label.o.value == "Machine Learning"
|
||||
|
||||
predicate_label = next((t for t in triples if t.s.value == predicate_uri and t.p.value == RDF_LABEL), None)
|
||||
|
||||
predicate_label = next((t for t in triples if t.s.iri == predicate_uri and t.p.iri == RDF_LABEL), None)
|
||||
assert predicate_label is not None
|
||||
assert predicate_label.o.value == "is_subset_of"
|
||||
|
||||
# Check main relationship triple
|
||||
# NOTE: Current implementation has bugs:
|
||||
# 1. Uses data.get("object-entity") instead of rel.get("object-entity")
|
||||
# 2. Sets object_value to predicate_uri instead of actual object URI
|
||||
# This test documents the current buggy behavior
|
||||
rel_triple = next((t for t in triples if t.s.value == subject_uri and t.p.value == predicate_uri), None)
|
||||
|
||||
# Check main relationship triple
|
||||
object_uri = f"{TRUSTGRAPH_ENTITIES}Artificial%20Intelligence"
|
||||
rel_triple = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == predicate_uri), None)
|
||||
assert rel_triple is not None
|
||||
# Due to bug, object value is set to predicate_uri
|
||||
assert rel_triple.o.value == predicate_uri
|
||||
|
||||
assert rel_triple.o.iri == object_uri
|
||||
assert rel_triple.o.type == IRI
|
||||
|
||||
# Check subject-of relationships
|
||||
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF and t.o.value == "doc123"]
|
||||
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF and t.o.iri == "doc123"]
|
||||
assert len(subject_of_triples) >= 2 # At least subject and predicate should have subject-of relations
|
||||
|
||||
def test_process_extraction_data_literal_object(self, agent_extractor, sample_metadata):
|
||||
"""Test processing of relationships with literal objects"""
|
||||
data = {
|
||||
"definitions": [],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "accuracy",
|
||||
"object": "95%",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
data = [
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "accuracy",
|
||||
"object": "95%",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
|
||||
# Check that object labels are not created for literal objects
|
||||
object_labels = [t for t in triples if t.p.value == RDF_LABEL and t.o.value == "95%"]
|
||||
object_labels = [t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "95%"]
|
||||
# Based on the code logic, it should not create object labels for non-entity objects
|
||||
# But there might be a bug in the original implementation
|
||||
|
||||
|
|
@ -263,75 +263,62 @@ class TestAgentKgExtractor:
|
|||
triples, entity_contexts = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata)
|
||||
|
||||
# Check that we have both definition and relationship triples
|
||||
definition_triples = [t for t in triples if t.p.value == DEFINITION]
|
||||
definition_triples = [t for t in triples if t.p.iri == DEFINITION]
|
||||
assert len(definition_triples) == 2 # Two definitions
|
||||
|
||||
|
||||
# Check entity contexts are created for definitions
|
||||
assert len(entity_contexts) == 2
|
||||
entity_uris = [ec.entity.value for ec in entity_contexts]
|
||||
entity_uris = [ec.entity.iri for ec in entity_contexts]
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris
|
||||
|
||||
def test_process_extraction_data_no_metadata_id(self, agent_extractor):
|
||||
"""Test processing when metadata has no ID"""
|
||||
metadata = Metadata(id=None, metadata=[])
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "Test Entity", "definition": "Test definition"}
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "Test Entity", "definition": "Test definition"}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should not create subject-of relationships when no metadata ID
|
||||
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF]
|
||||
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
|
||||
assert len(subject_of_triples) == 0
|
||||
|
||||
|
||||
# Should still create entity contexts
|
||||
assert len(entity_contexts) == 1
|
||||
|
||||
def test_process_extraction_data_empty_data(self, agent_extractor, sample_metadata):
|
||||
"""Test processing of empty extraction data"""
|
||||
data = {"definitions": [], "relationships": []}
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Should only have metadata triples
|
||||
assert len(entity_contexts) == 0
|
||||
# Triples should only contain metadata triples if any
|
||||
data = []
|
||||
|
||||
def test_process_extraction_data_missing_keys(self, agent_extractor, sample_metadata):
|
||||
"""Test processing data with missing keys"""
|
||||
# Test missing definitions key
|
||||
data = {"relationships": []}
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Should have no entity contexts
|
||||
assert len(entity_contexts) == 0
|
||||
|
||||
# Test missing relationships key
|
||||
data = {"definitions": []}
|
||||
# Triples should be empty
|
||||
assert len(triples) == 0
|
||||
|
||||
def test_process_extraction_data_unknown_types_ignored(self, agent_extractor, sample_metadata):
|
||||
"""Test processing data with unknown type values"""
|
||||
data = [
|
||||
{"type": "definition", "entity": "Valid", "definition": "Valid def"},
|
||||
{"type": "unknown_type", "foo": "bar"}, # Unknown type - should be ignored
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel", "object": "B", "object-entity": True}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
assert len(entity_contexts) == 0
|
||||
|
||||
# Test completely missing keys
|
||||
data = {}
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
assert len(entity_contexts) == 0
|
||||
|
||||
# Should process valid items and ignore unknown types
|
||||
assert len(entity_contexts) == 1 # Only the definition creates entity context
|
||||
|
||||
def test_process_extraction_data_malformed_entries(self, agent_extractor, sample_metadata):
|
||||
"""Test processing data with malformed entries"""
|
||||
# Test definition missing required fields
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "Test"}, # Missing definition
|
||||
{"definition": "Test def"} # Missing entity
|
||||
],
|
||||
"relationships": [
|
||||
{"subject": "A", "predicate": "rel"}, # Missing object
|
||||
{"subject": "B", "object": "C"} # Missing predicate
|
||||
]
|
||||
}
|
||||
|
||||
# Test items missing required fields - should raise KeyError
|
||||
data = [
|
||||
{"type": "definition", "entity": "Test"}, # Missing definition
|
||||
]
|
||||
|
||||
# Should handle gracefully or raise appropriate errors
|
||||
with pytest.raises(KeyError):
|
||||
agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
|
@ -340,17 +327,17 @@ class TestAgentKgExtractor:
|
|||
async def test_emit_triples(self, agent_extractor, sample_metadata):
|
||||
"""Test emitting triples to publisher"""
|
||||
mock_publisher = AsyncMock()
|
||||
|
||||
|
||||
test_triples = [
|
||||
Triple(
|
||||
s=Value(value="test:subject", is_uri=True),
|
||||
p=Value(value="test:predicate", is_uri=True),
|
||||
o=Value(value="test object", is_uri=False)
|
||||
s=Term(type=IRI, iri="test:subject"),
|
||||
p=Term(type=IRI, iri="test:predicate"),
|
||||
o=Term(type=LITERAL, value="test object")
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
await agent_extractor.emit_triples(mock_publisher, sample_metadata, test_triples)
|
||||
|
||||
|
||||
mock_publisher.send.assert_called_once()
|
||||
sent_triples = mock_publisher.send.call_args[0][0]
|
||||
assert isinstance(sent_triples, Triples)
|
||||
|
|
@ -361,22 +348,22 @@ class TestAgentKgExtractor:
|
|||
# Note: metadata.metadata is now empty array in the new implementation
|
||||
assert sent_triples.metadata.metadata == []
|
||||
assert len(sent_triples.triples) == 1
|
||||
assert sent_triples.triples[0].s.value == "test:subject"
|
||||
assert sent_triples.triples[0].s.iri == "test:subject"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_entity_contexts(self, agent_extractor, sample_metadata):
|
||||
"""Test emitting entity contexts to publisher"""
|
||||
mock_publisher = AsyncMock()
|
||||
|
||||
|
||||
test_contexts = [
|
||||
EntityContext(
|
||||
entity=Value(value="test:entity", is_uri=True),
|
||||
entity=Term(type=IRI, iri="test:entity"),
|
||||
context="Test context"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
await agent_extractor.emit_entity_contexts(mock_publisher, sample_metadata, test_contexts)
|
||||
|
||||
|
||||
mock_publisher.send.assert_called_once()
|
||||
sent_contexts = mock_publisher.send.call_args[0][0]
|
||||
assert isinstance(sent_contexts, EntityContexts)
|
||||
|
|
@ -387,7 +374,7 @@ class TestAgentKgExtractor:
|
|||
# Note: metadata.metadata is now empty array in the new implementation
|
||||
assert sent_contexts.metadata.metadata == []
|
||||
assert len(sent_contexts.entities) == 1
|
||||
assert sent_contexts.entities[0].entity.value == "test:entity"
|
||||
assert sent_contexts.entities[0].entity.iri == "test:entity"
|
||||
|
||||
def test_agent_extractor_initialization_params(self):
|
||||
"""Test agent extractor parameter validation"""
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import urllib.parse
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
|
||||
|
|
@ -32,11 +32,11 @@ class TestAgentKgExtractionEdgeCases:
|
|||
|
||||
# Set up the methods we want to test
|
||||
extractor.to_uri = real_extractor.to_uri
|
||||
extractor.parse_json = real_extractor.parse_json
|
||||
extractor.parse_jsonl = real_extractor.parse_jsonl
|
||||
extractor.process_extraction_data = real_extractor.process_extraction_data
|
||||
extractor.emit_triples = real_extractor.emit_triples
|
||||
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
|
||||
|
||||
|
||||
return extractor
|
||||
|
||||
def test_to_uri_special_characters(self, agent_extractor):
|
||||
|
|
@ -85,146 +85,116 @@ class TestAgentKgExtractionEdgeCases:
|
|||
# Verify the URI is properly encoded
|
||||
assert unicode_text not in uri # Original unicode should be encoded
|
||||
|
||||
def test_parse_json_whitespace_variations(self, agent_extractor):
|
||||
"""Test JSON parsing with various whitespace patterns"""
|
||||
# Test JSON with different whitespace patterns
|
||||
def test_parse_jsonl_whitespace_variations(self, agent_extractor):
|
||||
"""Test JSONL parsing with various whitespace patterns"""
|
||||
# Test JSONL with different whitespace patterns
|
||||
test_cases = [
|
||||
# Extra whitespace around code blocks
|
||||
" ```json\n{\"test\": true}\n``` ",
|
||||
# Tabs and mixed whitespace
|
||||
"\t\t```json\n\t{\"test\": true}\n\t```\t",
|
||||
# Multiple newlines
|
||||
"\n\n\n```json\n\n{\"test\": true}\n\n```\n\n",
|
||||
# JSON without code blocks but with whitespace
|
||||
" {\"test\": true} ",
|
||||
# Mixed line endings
|
||||
"```json\r\n{\"test\": true}\r\n```",
|
||||
' ```json\n{"type": "definition", "entity": "test", "definition": "def"}\n``` ',
|
||||
# Multiple newlines between lines
|
||||
'{"type": "definition", "entity": "A", "definition": "def A"}\n\n\n{"type": "definition", "entity": "B", "definition": "def B"}',
|
||||
# JSONL without code blocks but with whitespace
|
||||
' {"type": "definition", "entity": "test", "definition": "def"} ',
|
||||
]
|
||||
|
||||
for response in test_cases:
|
||||
result = agent_extractor.parse_json(response)
|
||||
assert result == {"test": True}
|
||||
|
||||
def test_parse_json_code_block_variations(self, agent_extractor):
|
||||
"""Test JSON parsing with different code block formats"""
|
||||
for response in test_cases:
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
assert len(result) >= 1
|
||||
assert result[0].get("type") == "definition"
|
||||
|
||||
def test_parse_jsonl_code_block_variations(self, agent_extractor):
|
||||
"""Test JSONL parsing with different code block formats"""
|
||||
test_cases = [
|
||||
# Standard json code block
|
||||
"```json\n{\"valid\": true}\n```",
|
||||
'```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
|
||||
# jsonl code block
|
||||
'```jsonl\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
|
||||
# Code block without language
|
||||
"```\n{\"valid\": true}\n```",
|
||||
# Uppercase JSON
|
||||
"```JSON\n{\"valid\": true}\n```",
|
||||
# Mixed case
|
||||
"```Json\n{\"valid\": true}\n```",
|
||||
# Multiple code blocks (should take first one)
|
||||
"```json\n{\"first\": true}\n```\n```json\n{\"second\": true}\n```",
|
||||
# Code block with extra content
|
||||
"Here's the result:\n```json\n{\"valid\": true}\n```\nDone!",
|
||||
'```\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
|
||||
# Code block with extra content before/after
|
||||
'Here\'s the result:\n```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```\nDone!',
|
||||
]
|
||||
|
||||
|
||||
for i, response in enumerate(test_cases):
|
||||
try:
|
||||
result = agent_extractor.parse_json(response)
|
||||
assert result.get("valid") == True or result.get("first") == True
|
||||
except json.JSONDecodeError:
|
||||
# Some cases may fail due to regex extraction issues
|
||||
# This documents current behavior - the regex may not match all cases
|
||||
print(f"Case {i} failed JSON parsing: {response[:50]}...")
|
||||
pass
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
assert len(result) >= 1, f"Case {i} failed"
|
||||
assert result[0].get("entity") == "A"
|
||||
|
||||
def test_parse_json_malformed_code_blocks(self, agent_extractor):
|
||||
"""Test JSON parsing with malformed code block formats"""
|
||||
# These should still work by falling back to treating entire text as JSON
|
||||
test_cases = [
|
||||
# Unclosed code block
|
||||
"```json\n{\"test\": true}",
|
||||
# No opening backticks
|
||||
"{\"test\": true}\n```",
|
||||
# Wrong number of backticks
|
||||
"`json\n{\"test\": true}\n`",
|
||||
# Nested backticks (should handle gracefully)
|
||||
"```json\n{\"code\": \"```\", \"test\": true}\n```",
|
||||
]
|
||||
|
||||
for response in test_cases:
|
||||
try:
|
||||
result = agent_extractor.parse_json(response)
|
||||
assert "test" in result # Should successfully parse
|
||||
except json.JSONDecodeError:
|
||||
# This is also acceptable for malformed cases
|
||||
pass
|
||||
def test_parse_jsonl_truncation_resilience(self, agent_extractor):
|
||||
"""Test JSONL parsing with truncated responses"""
|
||||
# Simulates LLM output being cut off mid-line
|
||||
response = '''{"type": "definition", "entity": "Complete1", "definition": "Full definition"}
|
||||
{"type": "definition", "entity": "Complete2", "definition": "Another full def"}
|
||||
{"type": "definition", "entity": "Trunca'''
|
||||
|
||||
def test_parse_json_large_responses(self, agent_extractor):
|
||||
"""Test JSON parsing with very large responses"""
|
||||
# Create a large JSON structure
|
||||
large_data = {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": f"Entity {i}",
|
||||
"definition": f"Definition {i} " + "with more content " * 100
|
||||
}
|
||||
for i in range(100)
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": f"Subject {i}",
|
||||
"predicate": f"predicate_{i}",
|
||||
"object": f"Object {i}",
|
||||
"object-entity": i % 2 == 0
|
||||
}
|
||||
for i in range(50)
|
||||
]
|
||||
}
|
||||
|
||||
large_json_str = json.dumps(large_data)
|
||||
response = f"```json\n{large_json_str}\n```"
|
||||
|
||||
result = agent_extractor.parse_json(response)
|
||||
|
||||
assert len(result["definitions"]) == 100
|
||||
assert len(result["relationships"]) == 50
|
||||
assert result["definitions"][0]["entity"] == "Entity 0"
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
# Should get 2 valid objects, the truncated line is skipped
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "Complete1"
|
||||
assert result[1]["entity"] == "Complete2"
|
||||
|
||||
def test_parse_jsonl_large_responses(self, agent_extractor):
|
||||
"""Test JSONL parsing with very large responses"""
|
||||
# Create a large JSONL response
|
||||
lines = []
|
||||
for i in range(100):
|
||||
lines.append(json.dumps({
|
||||
"type": "definition",
|
||||
"entity": f"Entity {i}",
|
||||
"definition": f"Definition {i} " + "with more content " * 100
|
||||
}))
|
||||
for i in range(50):
|
||||
lines.append(json.dumps({
|
||||
"type": "relationship",
|
||||
"subject": f"Subject {i}",
|
||||
"predicate": f"predicate_{i}",
|
||||
"object": f"Object {i}",
|
||||
"object-entity": i % 2 == 0
|
||||
}))
|
||||
|
||||
response = f"```json\n{chr(10).join(lines)}\n```"
|
||||
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
definitions = [r for r in result if r.get("type") == "definition"]
|
||||
relationships = [r for r in result if r.get("type") == "relationship"]
|
||||
|
||||
assert len(definitions) == 100
|
||||
assert len(relationships) == 50
|
||||
assert definitions[0]["entity"] == "Entity 0"
|
||||
|
||||
def test_process_extraction_data_empty_metadata(self, agent_extractor):
|
||||
"""Test processing with empty or minimal metadata"""
|
||||
# Test with None metadata - may not raise AttributeError depending on implementation
|
||||
try:
|
||||
triples, contexts = agent_extractor.process_extraction_data(
|
||||
{"definitions": [], "relationships": []},
|
||||
None
|
||||
)
|
||||
triples, contexts = agent_extractor.process_extraction_data([], None)
|
||||
# If it doesn't raise, check the results
|
||||
assert len(triples) == 0
|
||||
assert len(contexts) == 0
|
||||
except (AttributeError, TypeError):
|
||||
# This is expected behavior when metadata is None
|
||||
pass
|
||||
|
||||
|
||||
# Test with metadata without ID
|
||||
metadata = Metadata(id=None, metadata=[])
|
||||
triples, contexts = agent_extractor.process_extraction_data(
|
||||
{"definitions": [], "relationships": []},
|
||||
metadata
|
||||
)
|
||||
triples, contexts = agent_extractor.process_extraction_data([], metadata)
|
||||
assert len(triples) == 0
|
||||
assert len(contexts) == 0
|
||||
|
||||
|
||||
# Test with metadata with empty string ID
|
||||
metadata = Metadata(id="", metadata=[])
|
||||
data = {
|
||||
"definitions": [{"entity": "Test", "definition": "Test def"}],
|
||||
"relationships": []
|
||||
}
|
||||
data = [{"type": "definition", "entity": "Test", "definition": "Test def"}]
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should not create subject-of triples when ID is empty string
|
||||
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF]
|
||||
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
|
||||
assert len(subject_of_triples) == 0
|
||||
|
||||
def test_process_extraction_data_special_entity_names(self, agent_extractor):
|
||||
"""Test processing with special characters in entity names"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
|
||||
special_entities = [
|
||||
"Entity with spaces",
|
||||
"Entity & Co.",
|
||||
|
|
@ -237,71 +207,62 @@ class TestAgentKgExtractionEdgeCases:
|
|||
"Quotes: \"test\"",
|
||||
"Parentheses: (test)",
|
||||
]
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": entity, "definition": f"Definition for {entity}"}
|
||||
for entity in special_entities
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": entity, "definition": f"Definition for {entity}"}
|
||||
for entity in special_entities
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Verify all entities were processed
|
||||
assert len(contexts) == len(special_entities)
|
||||
|
||||
|
||||
# Verify URIs were properly encoded
|
||||
for i, entity in enumerate(special_entities):
|
||||
expected_uri = f"{TRUSTGRAPH_ENTITIES}{urllib.parse.quote(entity)}"
|
||||
assert contexts[i].entity.value == expected_uri
|
||||
assert contexts[i].entity.iri == expected_uri
|
||||
|
||||
def test_process_extraction_data_very_long_definitions(self, agent_extractor):
|
||||
"""Test processing with very long entity definitions"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
|
||||
# Create very long definition
|
||||
long_definition = "This is a very long definition. " * 1000
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "Test Entity", "definition": long_definition}
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "Test Entity", "definition": long_definition}
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should handle long definitions without issues
|
||||
assert len(contexts) == 1
|
||||
assert contexts[0].context == long_definition
|
||||
|
||||
|
||||
# Find definition triple
|
||||
def_triple = next((t for t in triples if t.p.value == DEFINITION), None)
|
||||
def_triple = next((t for t in triples if t.p.iri == DEFINITION), None)
|
||||
assert def_triple is not None
|
||||
assert def_triple.o.value == long_definition
|
||||
|
||||
def test_process_extraction_data_duplicate_entities(self, agent_extractor):
|
||||
"""Test processing with duplicate entity names"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "Machine Learning", "definition": "First definition"},
|
||||
{"entity": "Machine Learning", "definition": "Second definition"}, # Duplicate
|
||||
{"entity": "AI", "definition": "AI definition"},
|
||||
{"entity": "AI", "definition": "Another AI definition"}, # Duplicate
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "Machine Learning", "definition": "First definition"},
|
||||
{"type": "definition", "entity": "Machine Learning", "definition": "Second definition"}, # Duplicate
|
||||
{"type": "definition", "entity": "AI", "definition": "AI definition"},
|
||||
{"type": "definition", "entity": "AI", "definition": "Another AI definition"}, # Duplicate
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should process all entries (including duplicates)
|
||||
assert len(contexts) == 4
|
||||
|
||||
|
||||
# Check that both definitions for "Machine Learning" are present
|
||||
ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.value]
|
||||
ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.iri]
|
||||
assert len(ml_contexts) == 2
|
||||
assert ml_contexts[0].context == "First definition"
|
||||
assert ml_contexts[1].context == "Second definition"
|
||||
|
|
@ -309,49 +270,44 @@ class TestAgentKgExtractionEdgeCases:
|
|||
def test_process_extraction_data_empty_strings(self, agent_extractor):
|
||||
"""Test processing with empty strings in data"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "", "definition": "Definition for empty entity"},
|
||||
{"entity": "Valid Entity", "definition": ""},
|
||||
{"entity": " ", "definition": " "}, # Whitespace only
|
||||
],
|
||||
"relationships": [
|
||||
{"subject": "", "predicate": "test", "object": "test", "object-entity": True},
|
||||
{"subject": "test", "predicate": "", "object": "test", "object-entity": True},
|
||||
{"subject": "test", "predicate": "test", "object": "", "object-entity": True},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "", "definition": "Definition for empty entity"},
|
||||
{"type": "definition", "entity": "Valid Entity", "definition": ""},
|
||||
{"type": "definition", "entity": " ", "definition": " "}, # Whitespace only
|
||||
{"type": "relationship", "subject": "", "predicate": "test", "object": "test", "object-entity": True},
|
||||
{"type": "relationship", "subject": "test", "predicate": "", "object": "test", "object-entity": True},
|
||||
{"type": "relationship", "subject": "test", "predicate": "test", "object": "", "object-entity": True},
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should handle empty strings by creating URIs (even if empty)
|
||||
assert len(contexts) == 3
|
||||
|
||||
|
||||
# Empty entity should create empty URI after encoding
|
||||
empty_entity_context = next((ec for ec in contexts if ec.entity.value == TRUSTGRAPH_ENTITIES), None)
|
||||
empty_entity_context = next((ec for ec in contexts if ec.entity.iri == TRUSTGRAPH_ENTITIES), None)
|
||||
assert empty_entity_context is not None
|
||||
|
||||
def test_process_extraction_data_nested_json_in_strings(self, agent_extractor):
|
||||
"""Test processing when definitions contain JSON-like strings"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": "JSON Entity",
|
||||
"definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}'
|
||||
},
|
||||
{
|
||||
"entity": "Array Entity",
|
||||
"definition": 'Contains array: [1, 2, 3, "string"]'
|
||||
}
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "JSON Entity",
|
||||
"definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}'
|
||||
},
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "Array Entity",
|
||||
"definition": 'Contains array: [1, 2, 3, "string"]'
|
||||
}
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should handle JSON strings in definitions without parsing them
|
||||
assert len(contexts) == 2
|
||||
assert '{"key": "value"' in contexts[0].context
|
||||
|
|
@ -360,32 +316,29 @@ class TestAgentKgExtractionEdgeCases:
|
|||
def test_process_extraction_data_boolean_object_entity_variations(self, agent_extractor):
|
||||
"""Test processing with various boolean values for object-entity"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
data = {
|
||||
"definitions": [],
|
||||
"relationships": [
|
||||
# Explicit True
|
||||
{"subject": "A", "predicate": "rel1", "object": "B", "object-entity": True},
|
||||
# Explicit False
|
||||
{"subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False},
|
||||
# Missing object-entity (should default to True based on code)
|
||||
{"subject": "A", "predicate": "rel3", "object": "C"},
|
||||
# String "true" (should be treated as truthy)
|
||||
{"subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"},
|
||||
# String "false" (should be treated as truthy in Python)
|
||||
{"subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"},
|
||||
# Number 0 (falsy)
|
||||
{"subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0},
|
||||
# Number 1 (truthy)
|
||||
{"subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
# Explicit True
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel1", "object": "B", "object-entity": True},
|
||||
# Explicit False
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False},
|
||||
# Missing object-entity (should default to True based on code)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel3", "object": "C"},
|
||||
# String "true" (should be treated as truthy)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"},
|
||||
# String "false" (should be treated as truthy in Python)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"},
|
||||
# Number 0 (falsy)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0},
|
||||
# Number 1 (truthy)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should process all relationships
|
||||
# Note: The current implementation has some logic issues that these tests document
|
||||
assert len([t for t in triples if t.p.value != RDF_LABEL and t.p.value != SUBJECT_OF]) >= 7
|
||||
assert len([t for t in triples if t.p.iri != RDF_LABEL and t.p.iri != SUBJECT_OF]) >= 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_empty_collections(self, agent_extractor):
|
||||
|
|
@ -437,41 +390,40 @@ class TestAgentKgExtractionEdgeCases:
|
|||
def test_process_extraction_data_performance_large_dataset(self, agent_extractor):
|
||||
"""Test performance with large extraction datasets"""
|
||||
metadata = Metadata(id="large-doc", metadata=[])
|
||||
|
||||
# Create large dataset
|
||||
|
||||
# Create large dataset in JSONL format
|
||||
num_definitions = 1000
|
||||
num_relationships = 2000
|
||||
|
||||
large_data = {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": f"Entity_{i:04d}",
|
||||
"definition": f"Definition for entity {i} with some detailed explanation."
|
||||
}
|
||||
for i in range(num_definitions)
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": f"Entity_{i % num_definitions:04d}",
|
||||
"predicate": f"predicate_{i % 10}",
|
||||
"object": f"Entity_{(i + 1) % num_definitions:04d}",
|
||||
"object-entity": True
|
||||
}
|
||||
for i in range(num_relationships)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
large_data = [
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": f"Entity_{i:04d}",
|
||||
"definition": f"Definition for entity {i} with some detailed explanation."
|
||||
}
|
||||
for i in range(num_definitions)
|
||||
] + [
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": f"Entity_{i % num_definitions:04d}",
|
||||
"predicate": f"predicate_{i % 10}",
|
||||
"object": f"Entity_{(i + 1) % num_definitions:04d}",
|
||||
"object-entity": True
|
||||
}
|
||||
for i in range(num_relationships)
|
||||
]
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(large_data, metadata)
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
|
||||
# Should complete within reasonable time (adjust threshold as needed)
|
||||
assert processing_time < 10.0 # 10 seconds threshold
|
||||
|
||||
|
||||
# Verify results
|
||||
assert len(contexts) == num_definitions
|
||||
# Triples include labels, definitions, relationships, and subject-of relations
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ processing graph structures, and performing graph operations.
|
|||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from .conftest import Triple, Value, Metadata
|
||||
from .conftest import Triple, Metadata
|
||||
from collections import defaultdict, deque
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ def cities_schema():
|
|||
def validator():
|
||||
"""Create a mock processor with just the validation method"""
|
||||
from unittest.mock import MagicMock
|
||||
from trustgraph.extract.kg.objects.processor import Processor
|
||||
from trustgraph.extract.kg.rows.processor import Processor
|
||||
|
||||
# Create a mock processor
|
||||
mock_processor = MagicMock()
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@
|
|||
Unit tests for triple construction logic
|
||||
|
||||
Tests the core business logic for constructing RDF triples from extracted
|
||||
entities and relationships, including URI generation, Value object creation,
|
||||
entities and relationships, including URI generation, Term object creation,
|
||||
and triple validation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from .conftest import Triple, Triples, Value, Metadata
|
||||
from .conftest import Triple, Triples, Term, Metadata, IRI, LITERAL
|
||||
import re
|
||||
import hashlib
|
||||
|
||||
|
|
@ -48,80 +48,82 @@ class TestTripleConstructionLogic:
|
|||
generated_uri = generate_uri(text, entity_type)
|
||||
assert generated_uri == expected_uri, f"URI generation failed for '{text}'"
|
||||
|
||||
def test_value_object_creation(self):
|
||||
"""Test creation of Value objects for subjects, predicates, and objects"""
|
||||
def test_term_object_creation(self):
|
||||
"""Test creation of Term objects for subjects, predicates, and objects"""
|
||||
# Arrange
|
||||
def create_value_object(text, is_uri, value_type=""):
|
||||
return Value(
|
||||
value=text,
|
||||
is_uri=is_uri,
|
||||
type=value_type
|
||||
)
|
||||
|
||||
def create_term_object(text, is_uri, datatype=""):
|
||||
if is_uri:
|
||||
return Term(type=IRI, iri=text)
|
||||
else:
|
||||
return Term(type=LITERAL, value=text, datatype=datatype if datatype else None)
|
||||
|
||||
test_cases = [
|
||||
("http://trustgraph.ai/kg/person/john-smith", True, ""),
|
||||
("John Smith", False, "string"),
|
||||
("42", False, "integer"),
|
||||
("http://schema.org/worksFor", True, "")
|
||||
]
|
||||
|
||||
|
||||
# Act & Assert
|
||||
for value_text, is_uri, value_type in test_cases:
|
||||
value_obj = create_value_object(value_text, is_uri, value_type)
|
||||
|
||||
assert isinstance(value_obj, Value)
|
||||
assert value_obj.value == value_text
|
||||
assert value_obj.is_uri == is_uri
|
||||
assert value_obj.type == value_type
|
||||
for value_text, is_uri, datatype in test_cases:
|
||||
term_obj = create_term_object(value_text, is_uri, datatype)
|
||||
|
||||
assert isinstance(term_obj, Term)
|
||||
if is_uri:
|
||||
assert term_obj.type == IRI
|
||||
assert term_obj.iri == value_text
|
||||
else:
|
||||
assert term_obj.type == LITERAL
|
||||
assert term_obj.value == value_text
|
||||
|
||||
def test_triple_construction_from_relationship(self):
|
||||
"""Test constructing Triple objects from relationships"""
|
||||
# Arrange
|
||||
relationship = {
|
||||
"subject": "John Smith",
|
||||
"predicate": "works_for",
|
||||
"predicate": "works_for",
|
||||
"object": "OpenAI",
|
||||
"subject_type": "PERSON",
|
||||
"object_type": "ORG"
|
||||
}
|
||||
|
||||
|
||||
def construct_triple(relationship, uri_base="http://trustgraph.ai/kg"):
|
||||
# Generate URIs
|
||||
subject_uri = f"{uri_base}/person/{relationship['subject'].lower().replace(' ', '-')}"
|
||||
object_uri = f"{uri_base}/org/{relationship['object'].lower().replace(' ', '-')}"
|
||||
|
||||
|
||||
# Map predicate to schema.org URI
|
||||
predicate_mappings = {
|
||||
"works_for": "http://schema.org/worksFor",
|
||||
"located_in": "http://schema.org/location",
|
||||
"developed": "http://schema.org/creator"
|
||||
}
|
||||
predicate_uri = predicate_mappings.get(relationship["predicate"],
|
||||
predicate_uri = predicate_mappings.get(relationship["predicate"],
|
||||
f"{uri_base}/predicate/{relationship['predicate']}")
|
||||
|
||||
# Create Value objects
|
||||
subject_value = Value(value=subject_uri, is_uri=True, type="")
|
||||
predicate_value = Value(value=predicate_uri, is_uri=True, type="")
|
||||
object_value = Value(value=object_uri, is_uri=True, type="")
|
||||
|
||||
|
||||
# Create Term objects
|
||||
subject_term = Term(type=IRI, iri=subject_uri)
|
||||
predicate_term = Term(type=IRI, iri=predicate_uri)
|
||||
object_term = Term(type=IRI, iri=object_uri)
|
||||
|
||||
# Create Triple
|
||||
return Triple(
|
||||
s=subject_value,
|
||||
p=predicate_value,
|
||||
o=object_value
|
||||
s=subject_term,
|
||||
p=predicate_term,
|
||||
o=object_term
|
||||
)
|
||||
|
||||
|
||||
# Act
|
||||
triple = construct_triple(relationship)
|
||||
|
||||
|
||||
# Assert
|
||||
assert isinstance(triple, Triple)
|
||||
assert triple.s.value == "http://trustgraph.ai/kg/person/john-smith"
|
||||
assert triple.s.is_uri is True
|
||||
assert triple.p.value == "http://schema.org/worksFor"
|
||||
assert triple.p.is_uri is True
|
||||
assert triple.o.value == "http://trustgraph.ai/kg/org/openai"
|
||||
assert triple.o.is_uri is True
|
||||
assert triple.s.iri == "http://trustgraph.ai/kg/person/john-smith"
|
||||
assert triple.s.type == IRI
|
||||
assert triple.p.iri == "http://schema.org/worksFor"
|
||||
assert triple.p.type == IRI
|
||||
assert triple.o.iri == "http://trustgraph.ai/kg/org/openai"
|
||||
assert triple.o.type == IRI
|
||||
|
||||
def test_literal_value_handling(self):
|
||||
"""Test handling of literal values vs URI values"""
|
||||
|
|
@ -132,10 +134,10 @@ class TestTripleConstructionLogic:
|
|||
("John Smith", "email", "john@example.com", False), # Literal email
|
||||
("John Smith", "worksFor", "http://trustgraph.ai/kg/org/openai", True) # URI reference
|
||||
]
|
||||
|
||||
|
||||
def create_triple_with_literal(subject_uri, predicate, object_value, object_is_uri):
|
||||
subject_val = Value(value=subject_uri, is_uri=True, type="")
|
||||
|
||||
subject_term = Term(type=IRI, iri=subject_uri)
|
||||
|
||||
# Determine predicate URI
|
||||
predicate_mappings = {
|
||||
"name": "http://schema.org/name",
|
||||
|
|
@ -144,32 +146,37 @@ class TestTripleConstructionLogic:
|
|||
"worksFor": "http://schema.org/worksFor"
|
||||
}
|
||||
predicate_uri = predicate_mappings.get(predicate, f"http://trustgraph.ai/kg/predicate/{predicate}")
|
||||
predicate_val = Value(value=predicate_uri, is_uri=True, type="")
|
||||
|
||||
# Create object value with appropriate type
|
||||
object_type = ""
|
||||
if not object_is_uri:
|
||||
predicate_term = Term(type=IRI, iri=predicate_uri)
|
||||
|
||||
# Create object term with appropriate type
|
||||
if object_is_uri:
|
||||
object_term = Term(type=IRI, iri=object_value)
|
||||
else:
|
||||
datatype = None
|
||||
if predicate == "age":
|
||||
object_type = "integer"
|
||||
datatype = "integer"
|
||||
elif predicate in ["name", "email"]:
|
||||
object_type = "string"
|
||||
|
||||
object_val = Value(value=object_value, is_uri=object_is_uri, type=object_type)
|
||||
|
||||
return Triple(s=subject_val, p=predicate_val, o=object_val)
|
||||
|
||||
datatype = "string"
|
||||
object_term = Term(type=LITERAL, value=object_value, datatype=datatype)
|
||||
|
||||
return Triple(s=subject_term, p=predicate_term, o=object_term)
|
||||
|
||||
# Act & Assert
|
||||
for subject_uri, predicate, object_value, object_is_uri in test_data:
|
||||
subject_full_uri = "http://trustgraph.ai/kg/person/john-smith"
|
||||
triple = create_triple_with_literal(subject_full_uri, predicate, object_value, object_is_uri)
|
||||
|
||||
assert triple.o.is_uri == object_is_uri
|
||||
assert triple.o.value == object_value
|
||||
|
||||
|
||||
if object_is_uri:
|
||||
assert triple.o.type == IRI
|
||||
assert triple.o.iri == object_value
|
||||
else:
|
||||
assert triple.o.type == LITERAL
|
||||
assert triple.o.value == object_value
|
||||
|
||||
if predicate == "age":
|
||||
assert triple.o.type == "integer"
|
||||
assert triple.o.datatype == "integer"
|
||||
elif predicate in ["name", "email"]:
|
||||
assert triple.o.type == "string"
|
||||
assert triple.o.datatype == "string"
|
||||
|
||||
def test_namespace_management(self):
|
||||
"""Test namespace prefix management and expansion"""
|
||||
|
|
@ -216,63 +223,74 @@ class TestTripleConstructionLogic:
|
|||
def test_triple_validation(self):
|
||||
"""Test triple validation rules"""
|
||||
# Arrange
|
||||
def get_term_value(term):
|
||||
"""Extract value from a Term"""
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
else:
|
||||
return term.value
|
||||
|
||||
def validate_triple(triple):
|
||||
errors = []
|
||||
|
||||
|
||||
# Check required components
|
||||
if not triple.s or not triple.s.value:
|
||||
s_val = get_term_value(triple.s) if triple.s else None
|
||||
p_val = get_term_value(triple.p) if triple.p else None
|
||||
o_val = get_term_value(triple.o) if triple.o else None
|
||||
|
||||
if not triple.s or not s_val:
|
||||
errors.append("Missing or empty subject")
|
||||
|
||||
if not triple.p or not triple.p.value:
|
||||
|
||||
if not triple.p or not p_val:
|
||||
errors.append("Missing or empty predicate")
|
||||
|
||||
if not triple.o or not triple.o.value:
|
||||
|
||||
if not triple.o or not o_val:
|
||||
errors.append("Missing or empty object")
|
||||
|
||||
|
||||
# Check URI validity for URI values
|
||||
uri_pattern = r'^https?://[^\s/$.?#].[^\s]*$'
|
||||
|
||||
if triple.s.is_uri and not re.match(uri_pattern, triple.s.value):
|
||||
|
||||
if triple.s.type == IRI and not re.match(uri_pattern, triple.s.iri or ""):
|
||||
errors.append("Invalid subject URI format")
|
||||
|
||||
if triple.p.is_uri and not re.match(uri_pattern, triple.p.value):
|
||||
|
||||
if triple.p.type == IRI and not re.match(uri_pattern, triple.p.iri or ""):
|
||||
errors.append("Invalid predicate URI format")
|
||||
|
||||
if triple.o.is_uri and not re.match(uri_pattern, triple.o.value):
|
||||
|
||||
if triple.o.type == IRI and not re.match(uri_pattern, triple.o.iri or ""):
|
||||
errors.append("Invalid object URI format")
|
||||
|
||||
|
||||
# Predicates should typically be URIs
|
||||
if not triple.p.is_uri:
|
||||
if triple.p.type != IRI:
|
||||
errors.append("Predicate should be a URI")
|
||||
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
# Test valid triple
|
||||
valid_triple = Triple(
|
||||
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=Value(value="John Smith", is_uri=False, type="string")
|
||||
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=Term(type=LITERAL, value="John Smith", datatype="string")
|
||||
)
|
||||
|
||||
|
||||
# Test invalid triples
|
||||
invalid_triples = [
|
||||
Triple(s=Value(value="", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=Value(value="John", is_uri=False, type="")), # Empty subject
|
||||
|
||||
Triple(s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
|
||||
p=Value(value="name", is_uri=False, type=""), # Non-URI predicate
|
||||
o=Value(value="John", is_uri=False, type="")),
|
||||
|
||||
Triple(s=Value(value="invalid-uri", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=Value(value="John", is_uri=False, type="")) # Invalid URI format
|
||||
Triple(s=Term(type=IRI, iri=""),
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=Term(type=LITERAL, value="John")), # Empty subject
|
||||
|
||||
Triple(s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
|
||||
p=Term(type=LITERAL, value="name"), # Non-URI predicate
|
||||
o=Term(type=LITERAL, value="John")),
|
||||
|
||||
Triple(s=Term(type=IRI, iri="invalid-uri"),
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=Term(type=LITERAL, value="John")) # Invalid URI format
|
||||
]
|
||||
|
||||
|
||||
# Act & Assert
|
||||
is_valid, errors = validate_triple(valid_triple)
|
||||
assert is_valid, f"Valid triple failed validation: {errors}"
|
||||
|
||||
|
||||
for invalid_triple in invalid_triples:
|
||||
is_valid, errors = validate_triple(invalid_triple)
|
||||
assert not is_valid, f"Invalid triple passed validation: {invalid_triple}"
|
||||
|
|
@ -286,97 +304,97 @@ class TestTripleConstructionLogic:
|
|||
{"text": "OpenAI", "type": "ORG"},
|
||||
{"text": "San Francisco", "type": "PLACE"}
|
||||
]
|
||||
|
||||
|
||||
relationships = [
|
||||
{"subject": "John Smith", "predicate": "works_for", "object": "OpenAI"},
|
||||
{"subject": "OpenAI", "predicate": "located_in", "object": "San Francisco"}
|
||||
]
|
||||
|
||||
|
||||
def construct_triple_batch(entities, relationships, document_id="doc-1"):
|
||||
triples = []
|
||||
|
||||
|
||||
# Create type triples for entities
|
||||
for entity in entities:
|
||||
entity_uri = f"http://trustgraph.ai/kg/{entity['type'].lower()}/{entity['text'].lower().replace(' ', '-')}"
|
||||
type_uri = f"http://trustgraph.ai/kg/type/{entity['type']}"
|
||||
|
||||
|
||||
type_triple = Triple(
|
||||
s=Value(value=entity_uri, is_uri=True, type=""),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True, type=""),
|
||||
o=Value(value=type_uri, is_uri=True, type="")
|
||||
s=Term(type=IRI, iri=entity_uri),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri=type_uri)
|
||||
)
|
||||
triples.append(type_triple)
|
||||
|
||||
|
||||
# Create relationship triples
|
||||
for rel in relationships:
|
||||
subject_uri = f"http://trustgraph.ai/kg/entity/{rel['subject'].lower().replace(' ', '-')}"
|
||||
object_uri = f"http://trustgraph.ai/kg/entity/{rel['object'].lower().replace(' ', '-')}"
|
||||
predicate_uri = f"http://schema.org/{rel['predicate'].replace('_', '')}"
|
||||
|
||||
|
||||
rel_triple = Triple(
|
||||
s=Value(value=subject_uri, is_uri=True, type=""),
|
||||
p=Value(value=predicate_uri, is_uri=True, type=""),
|
||||
o=Value(value=object_uri, is_uri=True, type="")
|
||||
s=Term(type=IRI, iri=subject_uri),
|
||||
p=Term(type=IRI, iri=predicate_uri),
|
||||
o=Term(type=IRI, iri=object_uri)
|
||||
)
|
||||
triples.append(rel_triple)
|
||||
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
# Act
|
||||
triples = construct_triple_batch(entities, relationships)
|
||||
|
||||
|
||||
# Assert
|
||||
assert len(triples) == len(entities) + len(relationships) # Type triples + relationship triples
|
||||
|
||||
|
||||
# Check that all triples are valid Triple objects
|
||||
for triple in triples:
|
||||
assert isinstance(triple, Triple)
|
||||
assert triple.s.value != ""
|
||||
assert triple.p.value != ""
|
||||
assert triple.o.value != ""
|
||||
assert triple.s.iri != ""
|
||||
assert triple.p.iri != ""
|
||||
assert triple.o.iri != ""
|
||||
|
||||
def test_triples_batch_object_creation(self):
|
||||
"""Test creating Triples batch objects with metadata"""
|
||||
# Arrange
|
||||
sample_triples = [
|
||||
Triple(
|
||||
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=Value(value="John Smith", is_uri=False, type="string")
|
||||
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=Term(type=LITERAL, value="John Smith", datatype="string")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/worksFor", is_uri=True, type=""),
|
||||
o=Value(value="http://trustgraph.ai/kg/org/openai", is_uri=True, type="")
|
||||
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
|
||||
p=Term(type=IRI, iri="http://schema.org/worksFor"),
|
||||
o=Term(type=IRI, iri="http://trustgraph.ai/kg/org/openai")
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
|
||||
# Act
|
||||
triples_batch = Triples(
|
||||
metadata=metadata,
|
||||
triples=sample_triples
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
assert isinstance(triples_batch, Triples)
|
||||
assert triples_batch.metadata.id == "test-doc-123"
|
||||
assert triples_batch.metadata.user == "test_user"
|
||||
assert triples_batch.metadata.collection == "test_collection"
|
||||
assert len(triples_batch.triples) == 2
|
||||
|
||||
|
||||
# Check that triples are properly embedded
|
||||
for triple in triples_batch.triples:
|
||||
assert isinstance(triple, Triple)
|
||||
assert isinstance(triple.s, Value)
|
||||
assert isinstance(triple.p, Value)
|
||||
assert isinstance(triple.o, Value)
|
||||
assert isinstance(triple.s, Term)
|
||||
assert isinstance(triple.p, Term)
|
||||
assert isinstance(triple.o, Term)
|
||||
|
||||
def test_uri_collision_handling(self):
|
||||
"""Test handling of URI collisions and duplicate detection"""
|
||||
|
|
|
|||
|
|
@ -339,7 +339,250 @@ class TestPromptManager:
|
|||
"""Test PromptManager with minimal configuration"""
|
||||
pm = PromptManager()
|
||||
pm.load_config({}) # Empty config
|
||||
|
||||
|
||||
assert pm.config.system_template == "Be helpful." # Default system
|
||||
assert pm.terms == {} # Default empty terms
|
||||
assert len(pm.prompts) == 0
|
||||
assert len(pm.prompts) == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPromptManagerJsonl:
|
||||
"""Unit tests for PromptManager JSONL functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def jsonl_config(self):
|
||||
"""Configuration with JSONL response type prompts"""
|
||||
return {
|
||||
"system": json.dumps("You are an extraction assistant."),
|
||||
"template-index": json.dumps(["extract_simple", "extract_with_schema", "extract_mixed"]),
|
||||
"template.extract_simple": json.dumps({
|
||||
"prompt": "Extract entities from: {{ text }}",
|
||||
"response-type": "jsonl"
|
||||
}),
|
||||
"template.extract_with_schema": json.dumps({
|
||||
"prompt": "Extract definitions from: {{ text }}",
|
||||
"response-type": "jsonl",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": {"type": "string"},
|
||||
"definition": {"type": "string"}
|
||||
},
|
||||
"required": ["entity", "definition"]
|
||||
}
|
||||
}),
|
||||
"template.extract_mixed": json.dumps({
|
||||
"prompt": "Extract knowledge from: {{ text }}",
|
||||
"response-type": "jsonl",
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "definition"},
|
||||
"entity": {"type": "string"},
|
||||
"definition": {"type": "string"}
|
||||
},
|
||||
"required": ["type", "entity", "definition"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "relationship"},
|
||||
"subject": {"type": "string"},
|
||||
"predicate": {"type": "string"},
|
||||
"object": {"type": "string"}
|
||||
},
|
||||
"required": ["type", "subject", "predicate", "object"]
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_manager(self, jsonl_config):
|
||||
"""Create a PromptManager with JSONL configuration"""
|
||||
pm = PromptManager()
|
||||
pm.load_config(jsonl_config)
|
||||
return pm
|
||||
|
||||
def test_parse_jsonl_basic(self, prompt_manager):
|
||||
"""Test basic JSONL parsing"""
|
||||
text = '{"entity": "cat", "definition": "A small furry animal"}\n{"entity": "dog", "definition": "A loyal pet"}'
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "cat"
|
||||
assert result[1]["entity"] == "dog"
|
||||
|
||||
def test_parse_jsonl_with_empty_lines(self, prompt_manager):
|
||||
"""Test JSONL parsing skips empty lines"""
|
||||
text = '{"entity": "cat"}\n\n\n{"entity": "dog"}\n'
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_parse_jsonl_with_markdown_fences(self, prompt_manager):
|
||||
"""Test JSONL parsing strips markdown code fences"""
|
||||
text = '''```json
|
||||
{"entity": "cat", "definition": "A furry animal"}
|
||||
{"entity": "dog", "definition": "A loyal pet"}
|
||||
```'''
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "cat"
|
||||
assert result[1]["entity"] == "dog"
|
||||
|
||||
def test_parse_jsonl_with_jsonl_fence(self, prompt_manager):
|
||||
"""Test JSONL parsing strips jsonl-marked code fences"""
|
||||
text = '''```jsonl
|
||||
{"entity": "cat"}
|
||||
{"entity": "dog"}
|
||||
```'''
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_parse_jsonl_truncation_resilience(self, prompt_manager):
|
||||
"""Test JSONL parsing handles truncated final line"""
|
||||
text = '{"entity": "cat", "definition": "Complete"}\n{"entity": "dog", "defi'
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
# Should get the first valid object, skip the truncated one
|
||||
assert len(result) == 1
|
||||
assert result[0]["entity"] == "cat"
|
||||
|
||||
def test_parse_jsonl_invalid_lines_skipped(self, prompt_manager):
|
||||
"""Test JSONL parsing skips invalid JSON lines"""
|
||||
text = '''{"entity": "valid1"}
|
||||
not json at all
|
||||
{"entity": "valid2"}
|
||||
{broken json
|
||||
{"entity": "valid3"}'''
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0]["entity"] == "valid1"
|
||||
assert result[1]["entity"] == "valid2"
|
||||
assert result[2]["entity"] == "valid3"
|
||||
|
||||
def test_parse_jsonl_empty_input(self, prompt_manager):
|
||||
"""Test JSONL parsing with empty input"""
|
||||
result = prompt_manager.parse_jsonl("")
|
||||
assert result == []
|
||||
|
||||
result = prompt_manager.parse_jsonl("\n\n\n")
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_response(self, prompt_manager):
|
||||
"""Test invoking a prompt with JSONL response"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '{"entity": "photosynthesis", "definition": "Plant process"}\n{"entity": "mitosis", "definition": "Cell division"}'
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_simple",
|
||||
{"text": "Biology text"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "photosynthesis"
|
||||
assert result[1]["entity"] == "mitosis"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_with_schema_validation(self, prompt_manager):
|
||||
"""Test JSONL response with schema validation"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '{"entity": "cat", "definition": "A pet"}\n{"entity": "dog", "definition": "Another pet"}'
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_with_schema",
|
||||
{"text": "Animal text"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all("entity" in obj and "definition" in obj for obj in result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_schema_filters_invalid(self, prompt_manager):
|
||||
"""Test JSONL schema validation filters out invalid objects"""
|
||||
mock_llm = AsyncMock()
|
||||
# Second object is missing required 'definition' field
|
||||
mock_llm.return_value = '{"entity": "valid", "definition": "Has both fields"}\n{"entity": "invalid_missing_definition"}\n{"entity": "also_valid", "definition": "Complete"}'
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_with_schema",
|
||||
{"text": "Test text"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
# Only the two valid objects should be returned
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "valid"
|
||||
assert result[1]["entity"] == "also_valid"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_mixed_types(self, prompt_manager):
|
||||
"""Test JSONL with discriminated union schema (oneOf)"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '''{"type": "definition", "entity": "DNA", "definition": "Genetic material"}
|
||||
{"type": "relationship", "subject": "DNA", "predicate": "found_in", "object": "nucleus"}
|
||||
{"type": "definition", "entity": "RNA", "definition": "Messenger molecule"}'''
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_mixed",
|
||||
{"text": "Biology text"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
# Check definitions
|
||||
definitions = [r for r in result if r.get("type") == "definition"]
|
||||
assert len(definitions) == 2
|
||||
|
||||
# Check relationships
|
||||
relationships = [r for r in result if r.get("type") == "relationship"]
|
||||
assert len(relationships) == 1
|
||||
assert relationships[0]["subject"] == "DNA"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_empty_result(self, prompt_manager):
|
||||
"""Test JSONL response that yields no valid objects"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = "No JSON here at all"
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_simple",
|
||||
{"text": "Test"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_without_schema(self, prompt_manager):
|
||||
"""Test JSONL response without schema validation"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '{"any": "structure"}\n{"completely": "different"}'
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_simple",
|
||||
{"text": "Test"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0] == {"any": "structure"}
|
||||
assert result[1] == {"completely": "different"}
|
||||
|
|
@ -167,7 +167,7 @@ class TestFlowClient:
|
|||
expected_methods = [
|
||||
'text_completion', 'agent', 'graph_rag', 'document_rag',
|
||||
'graph_embeddings_query', 'embeddings', 'prompt',
|
||||
'triples_query', 'objects_query'
|
||||
'triples_query', 'rows_query'
|
||||
]
|
||||
|
||||
for method in expected_methods:
|
||||
|
|
@ -216,7 +216,7 @@ class TestSocketClient:
|
|||
expected_methods = [
|
||||
'agent', 'text_completion', 'graph_rag', 'document_rag',
|
||||
'prompt', 'graph_embeddings_query', 'embeddings',
|
||||
'triples_query', 'objects_query', 'mcp_tool'
|
||||
'triples_query', 'rows_query', 'mcp_tool'
|
||||
]
|
||||
|
||||
for method in expected_methods:
|
||||
|
|
@ -243,7 +243,7 @@ class TestBulkClient:
|
|||
'import_graph_embeddings',
|
||||
'import_document_embeddings',
|
||||
'import_entity_contexts',
|
||||
'import_objects'
|
||||
'import_rows'
|
||||
]
|
||||
|
||||
for method in import_methods:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.graph_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import Value, GraphEmbeddingsRequest
|
||||
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||
|
|
@ -68,50 +68,50 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
|
|
@ -138,17 +138,17 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
|
||||
)
|
||||
|
||||
# Verify results are converted to Value objects
|
||||
# Verify results are converted to Term objects
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[0], Value)
|
||||
assert result[0].value == "http://example.com/entity1"
|
||||
assert result[0].is_uri is True
|
||||
assert isinstance(result[1], Value)
|
||||
assert result[1].value == "http://example.com/entity2"
|
||||
assert result[1].is_uri is True
|
||||
assert isinstance(result[2], Value)
|
||||
assert isinstance(result[0], Term)
|
||||
assert result[0].iri == "http://example.com/entity1"
|
||||
assert result[0].type == IRI
|
||||
assert isinstance(result[1], Term)
|
||||
assert result[1].iri == "http://example.com/entity2"
|
||||
assert result[1].type == IRI
|
||||
assert isinstance(result[2], Term)
|
||||
assert result[2].value == "literal entity"
|
||||
assert result[2].is_uri is False
|
||||
assert result[2].type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor):
|
||||
|
|
@ -186,7 +186,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify results are deduplicated and limited
|
||||
assert len(result) == 3
|
||||
entity_values = [r.value for r in result]
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
assert "http://example.com/entity3" in entity_values
|
||||
|
|
@ -246,7 +246,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify duplicates are removed
|
||||
assert len(result) == 3
|
||||
entity_values = [r.value for r in result]
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
|
|
@ -346,14 +346,14 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
assert len(result) == 4
|
||||
|
||||
# Check URI entities
|
||||
uri_results = [r for r in result if r.is_uri]
|
||||
uri_results = [r for r in result if r.type == IRI]
|
||||
assert len(uri_results) == 2
|
||||
uri_values = [r.value for r in uri_results]
|
||||
uri_values = [r.iri for r in uri_results]
|
||||
assert "http://example.com/uri_entity" in uri_values
|
||||
assert "https://example.com/another_uri" in uri_values
|
||||
|
||||
# Check literal entities
|
||||
literal_results = [r for r in result if not r.is_uri]
|
||||
literal_results = [r for r in result if not r.type == IRI]
|
||||
assert len(literal_results) == 2
|
||||
literal_values = [r.value for r in literal_results]
|
||||
assert "literal entity text" in literal_values
|
||||
|
|
@ -486,7 +486,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify results from all dimensions
|
||||
assert len(result) == 3
|
||||
entity_values = [r.value for r in result]
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert "entity_2d" in entity_values
|
||||
assert "entity_4d" in entity_values
|
||||
assert "entity_3d" in entity_values
|
||||
|
|
@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
|
|||
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
|
||||
|
||||
from trustgraph.query.graph_embeddings.pinecone.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||
|
|
@ -105,27 +105,27 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
uri_entity = "http://example.org/entity"
|
||||
value = processor.create_value(uri_entity)
|
||||
|
||||
assert isinstance(value, Value)
|
||||
assert isinstance(value, Term)
|
||||
assert value.value == uri_entity
|
||||
assert value.is_uri == True
|
||||
assert value.type == IRI
|
||||
|
||||
def test_create_value_https_uri(self, processor):
|
||||
"""Test create_value method for HTTPS URI entities"""
|
||||
uri_entity = "https://example.org/entity"
|
||||
value = processor.create_value(uri_entity)
|
||||
|
||||
assert isinstance(value, Value)
|
||||
assert isinstance(value, Term)
|
||||
assert value.value == uri_entity
|
||||
assert value.is_uri == True
|
||||
assert value.type == IRI
|
||||
|
||||
def test_create_value_literal(self, processor):
|
||||
"""Test create_value method for literal entities"""
|
||||
literal_entity = "literal_entity"
|
||||
value = processor.create_value(literal_entity)
|
||||
|
||||
assert isinstance(value, Value)
|
||||
assert isinstance(value, Term)
|
||||
assert value.value == literal_entity
|
||||
assert value.is_uri == False
|
||||
assert value.type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
|
|
@ -165,11 +165,11 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
# Verify results
|
||||
assert len(entities) == 3
|
||||
assert entities[0].value == 'http://example.org/entity1'
|
||||
assert entities[0].is_uri == True
|
||||
assert entities[0].type == IRI
|
||||
assert entities[1].value == 'entity2'
|
||||
assert entities[1].is_uri == False
|
||||
assert entities[1].type == LITERAL
|
||||
assert entities[2].value == 'http://example.org/entity3'
|
||||
assert entities[2].is_uri == True
|
||||
assert entities[2].type == IRI
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||
|
||||
# Import the service under test
|
||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
|
|
@ -85,10 +86,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
value = processor.create_value('http://example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'http://example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
assert hasattr(value, 'iri')
|
||||
assert value.iri == 'http://example.com/entity'
|
||||
assert hasattr(value, 'type')
|
||||
assert value.type == IRI
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -109,10 +110,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
value = processor.create_value('https://secure.example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'https://secure.example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
assert hasattr(value, 'iri')
|
||||
assert value.iri == 'https://secure.example.com/entity'
|
||||
assert hasattr(value, 'type')
|
||||
assert value.type == IRI
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -135,8 +136,8 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'regular entity name'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == False
|
||||
assert hasattr(value, 'type')
|
||||
assert value.type == LITERAL
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -428,14 +429,14 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
assert len(result) == 3
|
||||
|
||||
# Check URI entities
|
||||
uri_entities = [entity for entity in result if hasattr(entity, 'is_uri') and entity.is_uri]
|
||||
uri_entities = [entity for entity in result if entity.type == IRI]
|
||||
assert len(uri_entities) == 2
|
||||
uri_values = [entity.value for entity in uri_entities]
|
||||
uri_values = [entity.iri for entity in uri_entities]
|
||||
assert 'http://example.com/entity1' in uri_values
|
||||
assert 'https://secure.example.com/entity2' in uri_values
|
||||
|
||||
|
||||
# Check regular entities
|
||||
regular_entities = [entity for entity in result if hasattr(entity, 'is_uri') and not entity.is_uri]
|
||||
regular_entities = [entity for entity in result if entity.type == LITERAL]
|
||||
assert len(regular_entities) == 1
|
||||
assert regular_entities[0].value == 'regular entity'
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.memgraph.service import Processor
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMemgraphQueryUserCollectionIsolation:
|
||||
|
|
@ -24,9 +24,9 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="test_object", is_uri=False),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="test_object"),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -65,8 +65,8 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=None,
|
||||
limit=1000
|
||||
)
|
||||
|
|
@ -105,9 +105,9 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=Value(value="http://example.com/o", is_uri=True),
|
||||
o=Term(type=IRI, iri="http://example.com/o"),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -145,7 +145,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=1000
|
||||
|
|
@ -185,8 +185,8 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="literal", is_uri=False),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="literal"),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -225,7 +225,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=None,
|
||||
limit=1000
|
||||
)
|
||||
|
|
@ -265,7 +265,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value="test_value", is_uri=False),
|
||||
o=Term(type=LITERAL, value="test_value"),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -355,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
# Query without user/collection fields
|
||||
query = TriplesQueryRequest(
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=1000
|
||||
|
|
@ -385,7 +385,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=1000
|
||||
|
|
@ -416,17 +416,17 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
assert len(result) == 2
|
||||
|
||||
# First triple (literal object)
|
||||
assert result[0].s.value == "http://example.com/s"
|
||||
assert result[0].s.is_uri == True
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].p.is_uri == True
|
||||
assert result[0].s.iri == "http://example.com/s"
|
||||
assert result[0].s.type == IRI
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].p.type == IRI
|
||||
assert result[0].o.value == "literal_value"
|
||||
assert result[0].o.is_uri == False
|
||||
|
||||
assert result[0].o.type == LITERAL
|
||||
|
||||
# Second triple (URI object)
|
||||
assert result[1].s.value == "http://example.com/s"
|
||||
assert result[1].s.is_uri == True
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].p.is_uri == True
|
||||
assert result[1].o.value == "http://example.com/o"
|
||||
assert result[1].o.is_uri == True
|
||||
assert result[1].s.iri == "http://example.com/s"
|
||||
assert result[1].s.type == IRI
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].p.type == IRI
|
||||
assert result[1].o.iri == "http://example.com/o"
|
||||
assert result[1].o.type == IRI
|
||||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.neo4j.service import Processor
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestNeo4jQueryUserCollectionIsolation:
|
||||
|
|
@ -24,21 +24,23 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="test_object", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="test_object"),
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify SPO query for literal includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src"
|
||||
"RETURN $src as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -63,23 +65,25 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=None
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify SP query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN dest.value as dest"
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
src="http://example.com/s",
|
||||
|
|
@ -88,13 +92,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
|
||||
# Verify SP query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN dest.uri as dest"
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -118,21 +123,23 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=Value(value="http://example.com/o", is_uri=True)
|
||||
o=Term(type=IRI, iri="http://example.com/o"),
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify SO query for nodes includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel"
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -156,23 +163,25 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify S query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest"
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
src="http://example.com/s",
|
||||
|
|
@ -194,20 +203,22 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="literal", is_uri=False)
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="literal"),
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify PO query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src"
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -232,20 +243,22 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=None
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify P query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest"
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -270,19 +283,21 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value="test_value", is_uri=False)
|
||||
o=Term(type=LITERAL, value="test_value"),
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify O query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel"
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -307,34 +322,37 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
o=None
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify wildcard query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
|
||||
# Verify wildcard query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest"
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -355,9 +373,10 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
# Query without user/collection fields
|
||||
query = TriplesQueryRequest(
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
|
@ -384,47 +403,48 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
# Mock some results
|
||||
mock_record1 = MagicMock()
|
||||
mock_record1.data.return_value = {
|
||||
"rel": "http://example.com/p1",
|
||||
"dest": "literal_value"
|
||||
}
|
||||
|
||||
|
||||
mock_record2 = MagicMock()
|
||||
mock_record2.data.return_value = {
|
||||
"rel": "http://example.com/p2",
|
||||
"dest": "http://example.com/o"
|
||||
}
|
||||
|
||||
|
||||
# Return results for literal query, empty for node query
|
||||
mock_driver.execute_query.side_effect = [
|
||||
([mock_record1], MagicMock(), MagicMock()), # Literal query
|
||||
([mock_record2], MagicMock(), MagicMock()) # Node query
|
||||
]
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify results are proper Triple objects
|
||||
assert len(result) == 2
|
||||
|
||||
# First triple (literal object)
|
||||
assert result[0].s.value == "http://example.com/s"
|
||||
assert result[0].s.is_uri == True
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].p.is_uri == True
|
||||
assert result[0].s.iri == "http://example.com/s"
|
||||
assert result[0].s.type == IRI
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].p.type == IRI
|
||||
assert result[0].o.value == "literal_value"
|
||||
assert result[0].o.is_uri == False
|
||||
|
||||
assert result[0].o.type == LITERAL
|
||||
|
||||
# Second triple (URI object)
|
||||
assert result[1].s.value == "http://example.com/s"
|
||||
assert result[1].s.is_uri == True
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].p.is_uri == True
|
||||
assert result[1].o.value == "http://example.com/o"
|
||||
assert result[1].o.is_uri == True
|
||||
assert result[1].s.iri == "http://example.com/s"
|
||||
assert result[1].s.type == IRI
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].p.type == IRI
|
||||
assert result[1].o.iri == "http://example.com/o"
|
||||
assert result[1].o.type == IRI
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
"""
|
||||
Unit tests for Cassandra Objects GraphQL Query Processor
|
||||
Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation)
|
||||
|
||||
Tests the business logic of the GraphQL query processor including:
|
||||
- GraphQL schema generation from RowSchema
|
||||
- Query execution and validation
|
||||
- CQL translation logic
|
||||
- Schema configuration handling
|
||||
- Query execution using unified rows table
|
||||
- Name sanitization
|
||||
- GraphQL query execution
|
||||
- Message processing logic
|
||||
"""
|
||||
|
||||
|
|
@ -12,119 +13,91 @@ import pytest
|
|||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
import strawberry
|
||||
from strawberry import Schema
|
||||
|
||||
from trustgraph.query.objects.cassandra.service import Processor
|
||||
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from trustgraph.query.rows.cassandra.service import Processor
|
||||
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||
from trustgraph.schema import RowSchema, Field
|
||||
|
||||
|
||||
class TestObjectsGraphQLQueryLogic:
|
||||
"""Test business logic without external dependencies"""
|
||||
|
||||
def test_get_python_type_mapping(self):
|
||||
"""Test schema field type conversion to Python types"""
|
||||
processor = MagicMock()
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
|
||||
# Basic type mappings
|
||||
assert processor.get_python_type("string") == str
|
||||
assert processor.get_python_type("integer") == int
|
||||
assert processor.get_python_type("float") == float
|
||||
assert processor.get_python_type("boolean") == bool
|
||||
assert processor.get_python_type("timestamp") == str
|
||||
assert processor.get_python_type("date") == str
|
||||
assert processor.get_python_type("time") == str
|
||||
assert processor.get_python_type("uuid") == str
|
||||
|
||||
# Unknown type defaults to str
|
||||
assert processor.get_python_type("unknown_type") == str
|
||||
|
||||
def test_create_graphql_type_basic_fields(self):
|
||||
"""Test GraphQL type creation for basic field types"""
|
||||
processor = MagicMock()
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Create test schema
|
||||
schema = RowSchema(
|
||||
name="test_table",
|
||||
description="Test table",
|
||||
fields=[
|
||||
Field(
|
||||
name="id",
|
||||
type="string",
|
||||
primary=True,
|
||||
required=True,
|
||||
description="Primary key"
|
||||
),
|
||||
Field(
|
||||
name="name",
|
||||
type="string",
|
||||
required=True,
|
||||
description="Name field"
|
||||
),
|
||||
Field(
|
||||
name="age",
|
||||
type="integer",
|
||||
required=False,
|
||||
description="Optional age"
|
||||
),
|
||||
Field(
|
||||
name="active",
|
||||
type="boolean",
|
||||
required=False,
|
||||
description="Status flag"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Create GraphQL type
|
||||
graphql_type = processor.create_graphql_type("test_table", schema)
|
||||
|
||||
# Verify type was created
|
||||
assert graphql_type is not None
|
||||
assert hasattr(graphql_type, '__name__')
|
||||
assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower()
|
||||
class TestRowsGraphQLQueryLogic:
|
||||
"""Test business logic for unified table query implementation"""
|
||||
|
||||
def test_sanitize_name_cassandra_compatibility(self):
|
||||
"""Test name sanitization for Cassandra field names"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
||||
# Test field name sanitization (matches storage processor)
|
||||
|
||||
# Test field name sanitization (uses r_ prefix like storage processor)
|
||||
assert processor.sanitize_name("simple_field") == "simple_field"
|
||||
assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes"
|
||||
assert processor.sanitize_name("field.with.dots") == "field_with_dots"
|
||||
assert processor.sanitize_name("123_field") == "o_123_field"
|
||||
assert processor.sanitize_name("123_field") == "r_123_field"
|
||||
assert processor.sanitize_name("field with spaces") == "field_with_spaces"
|
||||
assert processor.sanitize_name("special!@#chars") == "special___chars"
|
||||
assert processor.sanitize_name("UPPERCASE") == "uppercase"
|
||||
assert processor.sanitize_name("CamelCase") == "camelcase"
|
||||
|
||||
def test_sanitize_table_name(self):
|
||||
"""Test table name sanitization (always gets o_ prefix)"""
|
||||
def test_get_index_names(self):
|
||||
"""Test extraction of index names from schema"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
|
||||
# Table names always get o_ prefix
|
||||
assert processor.sanitize_table("simple_table") == "o_simple_table"
|
||||
assert processor.sanitize_table("Table-Name") == "o_table_name"
|
||||
assert processor.sanitize_table("123table") == "o_123table"
|
||||
assert processor.sanitize_table("") == "o_"
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string"), # Not indexed
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
index_names = processor.get_index_names(schema)
|
||||
|
||||
assert "id" in index_names
|
||||
assert "category" in index_names
|
||||
assert "status" in index_names
|
||||
assert "name" not in index_names
|
||||
assert len(index_names) == 3
|
||||
|
||||
def test_find_matching_index_exact_match(self):
|
||||
"""Test finding matching index for exact match query"""
|
||||
processor = MagicMock()
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string") # Not indexed
|
||||
]
|
||||
)
|
||||
|
||||
# Filter on indexed field should return match
|
||||
filters = {"category": "electronics"}
|
||||
result = processor.find_matching_index(schema, filters)
|
||||
assert result is not None
|
||||
assert result[0] == "category"
|
||||
assert result[1] == ["electronics"]
|
||||
|
||||
# Filter on non-indexed field should return None
|
||||
filters = {"name": "test"}
|
||||
result = processor.find_matching_index(schema, filters)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configuration"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.graphql_types = {}
|
||||
processor.graphql_schema = None
|
||||
processor.config_key = "schema" # Set the config key
|
||||
processor.generate_graphql_schema = AsyncMock()
|
||||
processor.config_key = "schema"
|
||||
processor.schema_builder = MagicMock()
|
||||
processor.schema_builder.clear = MagicMock()
|
||||
processor.schema_builder.add_schema = MagicMock()
|
||||
processor.schema_builder.build = MagicMock(return_value=MagicMock())
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Create test config
|
||||
schema_config = {
|
||||
"schema": {
|
||||
|
|
@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic:
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Process config
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer" in processor.schemas
|
||||
schema = processor.schemas["customer"]
|
||||
assert schema.name == "customer"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
||||
# Verify fields
|
||||
id_field = next(f for f in schema.fields if f.name == "id")
|
||||
assert id_field.primary is True
|
||||
# The field should have been created correctly from JSON
|
||||
# Let's test what we can verify - that the field has the right attributes
|
||||
assert hasattr(id_field, 'required') # Has the required attribute
|
||||
assert hasattr(id_field, 'primary') # Has the primary attribute
|
||||
|
||||
|
||||
email_field = next(f for f in schema.fields if f.name == "email")
|
||||
assert email_field.indexed is True
|
||||
|
||||
|
||||
status_field = next(f for f in schema.fields if f.name == "status")
|
||||
assert status_field.enum_values == ["active", "inactive"]
|
||||
|
||||
# Verify GraphQL schema regeneration was called
|
||||
processor.generate_graphql_schema.assert_called_once()
|
||||
|
||||
def test_cql_query_building_basic(self):
|
||||
"""Test basic CQL query construction"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to capture the query
|
||||
mock_result = []
|
||||
processor.session.execute.return_value = mock_result
|
||||
|
||||
# Create test schema
|
||||
schema = RowSchema(
|
||||
name="test_table",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string", indexed=True),
|
||||
Field(name="status", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
# Test query building
|
||||
asyncio = pytest.importorskip("asyncio")
|
||||
|
||||
async def run_test():
|
||||
await processor.query_cassandra(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="test_table",
|
||||
row_schema=schema,
|
||||
filters={"name": "John", "invalid_filter": "ignored"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Run the async test
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(run_test())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Verify Cassandra connection and query execution
|
||||
processor.connect_cassandra.assert_called_once()
|
||||
processor.session.execute.assert_called_once()
|
||||
|
||||
# Verify the query structure (can't easily test exact query without complex mocking)
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0] # First positional argument is the query
|
||||
params = call_args[0][1] # Second positional argument is parameters
|
||||
|
||||
# Basic query structure checks
|
||||
assert "SELECT * FROM test_user.o_test_table" in query
|
||||
assert "WHERE" in query
|
||||
assert "collection = %s" in query
|
||||
assert "LIMIT 10" in query
|
||||
|
||||
# Parameters should include collection and name filter
|
||||
assert "test_collection" in params
|
||||
assert "John" in params
|
||||
# Verify schema builder was called
|
||||
processor.schema_builder.add_schema.assert_called_once()
|
||||
processor.schema_builder.build.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphql_context_handling(self):
|
||||
|
|
@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic:
|
|||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Mock schema execution
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
|
||||
mock_result.errors = None
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
|
|
@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic:
|
|||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify schema.execute was called with correct context
|
||||
processor.graphql_schema.execute.assert_called_once()
|
||||
call_args = processor.graphql_schema.execute.call_args
|
||||
|
||||
|
||||
# Verify context was passed
|
||||
context = call_args[1]['context_value'] # keyword argument
|
||||
context = call_args[1]['context_value']
|
||||
assert context["processor"] == processor
|
||||
assert context["user"] == "test_user"
|
||||
assert context["collection"] == "test_collection"
|
||||
|
||||
|
||||
# Verify result structure
|
||||
assert "data" in result
|
||||
assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]}
|
||||
|
|
@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic:
|
|||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Create a simple object to simulate GraphQL error instead of MagicMock
|
||||
|
||||
# Create a simple object to simulate GraphQL error
|
||||
class MockError:
|
||||
def __init__(self, message, path, extensions):
|
||||
self.message = message
|
||||
self.path = path
|
||||
self.extensions = extensions
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
mock_error = MockError(
|
||||
message="Field 'invalid_field' doesn't exist",
|
||||
path=["customers", "0", "invalid_field"],
|
||||
extensions={"code": "FIELD_NOT_FOUND"}
|
||||
)
|
||||
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = None
|
||||
mock_result.errors = [mock_error]
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
query='{ customers { invalid_field } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify error handling
|
||||
assert "errors" in result
|
||||
assert len(result["errors"]) == 1
|
||||
|
||||
|
||||
error = result["errors"][0]
|
||||
assert error["message"] == "Field 'invalid_field' doesn't exist"
|
||||
assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path
|
||||
assert error["path"] == ["customers", "0", "invalid_field"]
|
||||
assert error["extensions"] == {"code": "FIELD_NOT_FOUND"}
|
||||
|
||||
def test_schema_generation_basic_structure(self):
|
||||
"""Test basic GraphQL schema generation structure"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"customer": RowSchema(
|
||||
name="customer",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.graphql_types = {}
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Test individual type creation (avoiding the full schema generation which has annotation issues)
|
||||
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
|
||||
processor.graphql_types["customer"] = graphql_type
|
||||
|
||||
# Verify type was created
|
||||
assert len(processor.graphql_types) == 1
|
||||
assert "customer" in processor.graphql_types
|
||||
assert processor.graphql_types["customer"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_processing_success(self):
|
||||
"""Test successful message processing flow"""
|
||||
processor = MagicMock()
|
||||
processor.execute_graphql_query = AsyncMock()
|
||||
processor.on_message = Processor.on_message.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Mock successful query result
|
||||
processor.execute_graphql_query.return_value = {
|
||||
"data": {"customers": [{"id": "1", "name": "John"}]},
|
||||
"errors": [],
|
||||
"extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String())
|
||||
"extensions": {}
|
||||
}
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = ObjectsQueryRequest(
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
collection="test_collection",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None
|
||||
)
|
||||
mock_msg.value.return_value = mock_request
|
||||
mock_msg.properties.return_value = {"id": "test-123"}
|
||||
|
||||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
||||
# Process message
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
||||
# Verify query was executed
|
||||
processor.execute_graphql_query.assert_called_once_with(
|
||||
query='{ customers { id name } }',
|
||||
|
|
@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic:
|
|||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify response was sent
|
||||
mock_response_flow.send.assert_called_once()
|
||||
response_call = mock_response_flow.send.call_args[0][0]
|
||||
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response_call, ObjectsQueryResponse)
|
||||
assert isinstance(response_call, RowsQueryResponse)
|
||||
assert response_call.error is None
|
||||
assert '"customers"' in response_call.data # JSON encoded
|
||||
assert len(response_call.errors) == 0
|
||||
|
|
@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic:
|
|||
processor = MagicMock()
|
||||
processor.execute_graphql_query = AsyncMock()
|
||||
processor.on_message = Processor.on_message.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Mock query execution error
|
||||
processor.execute_graphql_query.side_effect = RuntimeError("No schema available")
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = ObjectsQueryRequest(
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ invalid_query }',
|
||||
|
|
@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic:
|
|||
)
|
||||
mock_msg.value.return_value = mock_request
|
||||
mock_msg.properties.return_value = {"id": "test-456"}
|
||||
|
||||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
||||
# Process message
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
||||
# Verify error response was sent
|
||||
mock_response_flow.send.assert_called_once()
|
||||
response_call = mock_response_flow.send.call_args[0][0]
|
||||
|
||||
|
||||
# Verify error response structure
|
||||
assert isinstance(response_call, ObjectsQueryResponse)
|
||||
assert isinstance(response_call, RowsQueryResponse)
|
||||
assert response_call.error is not None
|
||||
assert response_call.error.type == "objects-query-error"
|
||||
assert response_call.error.type == "rows-query-error"
|
||||
assert "No schema available" in response_call.error.message
|
||||
assert response_call.data is None
|
||||
|
||||
|
||||
class TestCQLQueryGeneration:
|
||||
"""Test CQL query generation logic in isolation"""
|
||||
|
||||
def test_partition_key_inclusion(self):
|
||||
"""Test that collection is always included in queries"""
|
||||
class TestUnifiedTableQueries:
|
||||
"""Test queries against the unified rows table"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_index_match(self):
|
||||
"""Test query execution with matching index"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
|
||||
# Mock the query building (simplified version)
|
||||
keyspace = processor.sanitize_name("test_user")
|
||||
table = processor.sanitize_table("test_table")
|
||||
|
||||
query = f"SELECT * FROM {keyspace}.{table}"
|
||||
where_clauses = ["collection = %s"]
|
||||
|
||||
assert "collection = %s" in where_clauses
|
||||
assert keyspace == "test_user"
|
||||
assert table == "o_test_table"
|
||||
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to return test data
|
||||
mock_row = MagicMock()
|
||||
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
|
||||
processor.session.execute.return_value = [mock_row]
|
||||
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
# Query with filter on indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
filters={"category": "electronics"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Verify Cassandra was connected and queried
|
||||
processor.connect_cassandra.assert_called_once()
|
||||
processor.session.execute.assert_called_once()
|
||||
|
||||
# Verify query structure - should query unified rows table
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
params = call_args[0][1]
|
||||
|
||||
assert "SELECT data, source FROM test_user.rows" in query
|
||||
assert "collection = %s" in query
|
||||
assert "schema_name = %s" in query
|
||||
assert "index_name = %s" in query
|
||||
assert "index_value = %s" in query
|
||||
|
||||
assert params[0] == "test_collection"
|
||||
assert params[1] == "products"
|
||||
assert params[2] == "category"
|
||||
assert params[3] == ["electronics"]
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "123"
|
||||
assert results[0]["category"] == "electronics"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_without_index_match(self):
|
||||
"""Test query execution without matching index (scan mode)"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to return test data
|
||||
mock_row1 = MagicMock()
|
||||
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
|
||||
mock_row2 = MagicMock()
|
||||
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
|
||||
processor.session.execute.return_value = [mock_row1, mock_row2]
|
||||
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string"), # Not indexed
|
||||
Field(name="price", type="string") # Not indexed
|
||||
]
|
||||
)
|
||||
|
||||
# Query with filter on non-indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
filters={"name": "Product A"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Query should use ALLOW FILTERING for scan
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
|
||||
assert "ALLOW FILTERING" in query
|
||||
|
||||
# Should post-filter results
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "Product A"
|
||||
|
||||
|
||||
class TestFilterMatching:
|
||||
"""Test filter matching logic"""
|
||||
|
||||
def test_matches_filters_exact_match(self):
|
||||
"""Test exact match filter"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
|
||||
|
||||
row = {"status": "active", "name": "test"}
|
||||
assert processor._matches_filters(row, {"status": "active"}, schema) is True
|
||||
assert processor._matches_filters(row, {"status": "inactive"}, schema) is False
|
||||
|
||||
def test_matches_filters_comparison_operators(self):
|
||||
"""Test comparison operators in filters"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="price", type="float")])
|
||||
|
||||
row = {"price": "100.0"}
|
||||
|
||||
# Greater than
|
||||
assert processor._matches_filters(row, {"price_gt": 50}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_gt": 150}, schema) is False
|
||||
|
||||
# Less than
|
||||
assert processor._matches_filters(row, {"price_lt": 150}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_lt": 50}, schema) is False
|
||||
|
||||
# Greater than or equal
|
||||
assert processor._matches_filters(row, {"price_gte": 100}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_gte": 101}, schema) is False
|
||||
|
||||
# Less than or equal
|
||||
assert processor._matches_filters(row, {"price_lte": 100}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_lte": 99}, schema) is False
|
||||
|
||||
def test_matches_filters_contains(self):
|
||||
"""Test contains filter"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="description", type="string")])
|
||||
|
||||
row = {"description": "A great product for everyone"}
|
||||
|
||||
assert processor._matches_filters(row, {"description_contains": "great"}, schema) is True
|
||||
assert processor._matches_filters(row, {"description_contains": "terrible"}, schema) is False
|
||||
|
||||
def test_matches_filters_in_list(self):
|
||||
"""Test in-list filter"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
|
||||
|
||||
row = {"status": "active"}
|
||||
|
||||
assert processor._matches_filters(row, {"status_in": ["active", "pending"]}, schema) is True
|
||||
assert processor._matches_filters(row, {"status_in": ["inactive", "deleted"]}, schema) is False
|
||||
|
||||
|
||||
class TestIndexedFieldFiltering:
|
||||
"""Test that only indexed or primary key fields can be directly filtered"""
|
||||
|
||||
def test_indexed_field_filtering(self):
|
||||
"""Test that only indexed or primary key fields can be filtered"""
|
||||
# Create schema with mixed field types
|
||||
schema = RowSchema(
|
||||
name="test",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="indexed_field", type="string", indexed=True),
|
||||
Field(name="indexed_field", type="string", indexed=True),
|
||||
Field(name="normal_field", type="string", indexed=False),
|
||||
Field(name="another_field", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
filters = {
|
||||
"id": "test123", # Primary key - should be included
|
||||
"indexed_field": "value", # Indexed - should be included
|
||||
"normal_field": "ignored", # Not indexed - should be ignored
|
||||
"another_field": "also_ignored" # Not indexed - should be ignored
|
||||
}
|
||||
|
||||
|
||||
# Simulate the filtering logic from the processor
|
||||
valid_filters = []
|
||||
for field_name, value in filters.items():
|
||||
|
|
@ -492,7 +531,7 @@ class TestCQLQueryGeneration:
|
|||
schema_field = next((f for f in schema.fields if f.name == field_name), None)
|
||||
if schema_field and (schema_field.indexed or schema_field.primary):
|
||||
valid_filters.append((field_name, value))
|
||||
|
||||
|
||||
# Only id and indexed_field should be included
|
||||
assert len(valid_filters) == 2
|
||||
field_names = [f[0] for f in valid_filters]
|
||||
|
|
@ -500,52 +539,3 @@ class TestCQLQueryGeneration:
|
|||
assert "indexed_field" in field_names
|
||||
assert "normal_field" not in field_names
|
||||
assert "another_field" not in field_names
|
||||
|
||||
|
||||
class TestGraphQLSchemaGeneration:
|
||||
"""Test GraphQL schema generation in detail"""
|
||||
|
||||
def test_field_type_annotations(self):
|
||||
"""Test that GraphQL types have correct field annotations"""
|
||||
processor = MagicMock()
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Create schema with various field types
|
||||
schema = RowSchema(
|
||||
name="test",
|
||||
fields=[
|
||||
Field(name="id", type="string", required=True, primary=True),
|
||||
Field(name="count", type="integer", required=True),
|
||||
Field(name="price", type="float", required=False),
|
||||
Field(name="active", type="boolean", required=False),
|
||||
Field(name="optional_text", type="string", required=False)
|
||||
]
|
||||
)
|
||||
|
||||
# Create GraphQL type
|
||||
graphql_type = processor.create_graphql_type("test", schema)
|
||||
|
||||
# Verify type was created successfully
|
||||
assert graphql_type is not None
|
||||
|
||||
def test_basic_type_creation(self):
|
||||
"""Test that GraphQL types are created correctly"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"customer": RowSchema(
|
||||
name="customer",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
}
|
||||
processor.graphql_types = {}
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Create GraphQL type directly
|
||||
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
|
||||
processor.graphql_types["customer"] = graphql_type
|
||||
|
||||
# Verify customer type was created
|
||||
assert "customer" in processor.graphql_types
|
||||
assert processor.graphql_types["customer"] is not None
|
||||
|
|
@ -5,8 +5,8 @@ Tests for Cassandra triples query service
|
|||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.cassandra.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
from trustgraph.query.triples.cassandra.service import Processor, create_term
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestCassandraQueryProcessor:
|
||||
|
|
@ -21,94 +21,101 @@ class TestCassandraQueryProcessor:
|
|||
graph_host='localhost'
|
||||
)
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
def test_create_term_with_http_uri(self, processor):
|
||||
"""Test create_term with HTTP URI"""
|
||||
result = create_term("http://example.com/resource")
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
def test_create_term_with_https_uri(self, processor):
|
||||
"""Test create_term with HTTPS URI"""
|
||||
result = create_term("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_term_with_literal(self, processor):
|
||||
"""Test create_term with literal value"""
|
||||
result = create_term("just a literal string")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
def test_create_term_with_empty_string(self, processor):
|
||||
"""Test create_term with empty string"""
|
||||
result = create_term("")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
def test_create_term_with_partial_uri(self, processor):
|
||||
"""Test create_term with string that looks like URI but isn't complete"""
|
||||
result = create_term("http")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
def test_create_term_with_ftp_uri(self, processor):
|
||||
"""Test create_term with FTP URI (should not be detected as URI)"""
|
||||
result = create_term("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_spo_query(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_spo_query(self, mock_kg_class):
|
||||
"""Test querying triples with subject, predicate, and object specified"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
# Setup mock TrustGraph via factory function
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None # SPO query returns None if found
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
# SPO query returns a list of results (with mock graph attribute)
|
||||
mock_result = MagicMock()
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
cassandra_host='localhost'
|
||||
)
|
||||
|
||||
|
||||
# Create query request with all SPO values
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify KnowledgeGraph was created with correct parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
mock_kg_class.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='test_user'
|
||||
)
|
||||
|
||||
|
||||
# Verify get_spo was called with correct parameters
|
||||
mock_tg_instance.get_spo.assert_called_once_with(
|
||||
'test_collection', 'test_subject', 'test_predicate', 'test_object', limit=100
|
||||
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
|
||||
)
|
||||
|
||||
|
||||
# Verify result contains the queried triple
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
|
|
@ -143,154 +150,174 @@ class TestCassandraQueryProcessor:
|
|||
assert processor.table is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_sp_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_sp_pattern(self, mock_kg_class):
|
||||
"""Test SP query pattern (subject and predicate, no object)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph and response
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
# Setup mock TrustGraph via factory function
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.o = 'result_object'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=None,
|
||||
limit=50
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', limit=50)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_s_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_s_pattern(self, mock_kg_class):
|
||||
"""Test S query pattern (subject only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_result.o = 'result_object'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_s.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=25
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', limit=25)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_p_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_p_pattern(self, mock_kg_class):
|
||||
"""Test P query pattern (predicate only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.o = 'result_object'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_p.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', limit=10)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_o_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_o_pattern(self, mock_kg_class):
|
||||
"""Test O query pattern (object only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_o.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=75
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', limit=75)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_get_all_pattern(self, mock_kg_class):
|
||||
"""Test query pattern with no constraints (get all)"""
|
||||
from trustgraph.schema import TriplesQueryRequest
|
||||
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'all_subject'
|
||||
mock_result.p = 'all_predicate'
|
||||
mock_result.o = 'all_object'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_all.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
|
|
@ -299,9 +326,9 @@ class TestCassandraQueryProcessor:
|
|||
o=None,
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'all_subject'
|
||||
|
|
@ -372,37 +399,44 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
|
||||
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o, g) quad pattern, some values may be\nnull. Output is a list of quads.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_with_authentication(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_with_authentication(self, mock_kg_class):
|
||||
"""Test querying with username and password authentication"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
# SPO query returns a list of results
|
||||
mock_result = MagicMock()
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
cassandra_username='authuser',
|
||||
cassandra_password='authpass'
|
||||
)
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify KnowledgeGraph was created with authentication
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
mock_kg_class.assert_called_once_with(
|
||||
hosts=['cassandra'], # Updated default
|
||||
keyspace='test_user',
|
||||
username='authuser',
|
||||
|
|
@ -410,128 +444,154 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_table_reuse(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_table_reuse(self, mock_kg_class):
|
||||
"""Test that TrustGraph is reused for same table"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
# SPO query returns a list of results
|
||||
mock_result = MagicMock()
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
# First query should create TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
assert mock_kg_class.call_count == 1
|
||||
|
||||
# Second query with same table should reuse TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
assert mock_kg_class.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_table_switching(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_table_switching(self, mock_kg_class):
|
||||
"""Test table switching creates new TrustGraph"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
# Setup mock results for both instances
|
||||
mock_result = MagicMock()
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.p = 'p'
|
||||
mock_result.o = 'o'
|
||||
mock_tg_instance1.get_s.return_value = [mock_result]
|
||||
mock_tg_instance2.get_s.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# First query
|
||||
query1 = TriplesQueryRequest(
|
||||
user='user1',
|
||||
collection='collection1',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
await processor.query_triples(query1)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
user='user2',
|
||||
collection='collection2',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
await processor.query_triples(query2)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
assert mock_trustgraph.call_count == 2
|
||||
assert mock_kg_class.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_exception_handling(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_exception_handling(self, mock_kg_class):
|
||||
"""Test exception handling during query execution"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_triples(query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_multiple_results(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_multiple_results(self, mock_kg_class):
|
||||
"""Test query returning multiple results"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
# Mock multiple results
|
||||
mock_result1 = MagicMock()
|
||||
mock_result1.o = 'object1'
|
||||
mock_result1.g = ''
|
||||
mock_result1.otype = None
|
||||
mock_result1.dtype = None
|
||||
mock_result1.lang = None
|
||||
mock_result2 = MagicMock()
|
||||
mock_result2.o = 'object2'
|
||||
mock_result2.g = ''
|
||||
mock_result2.otype = None
|
||||
mock_result2.dtype = None
|
||||
mock_result2.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].o.value == 'object1'
|
||||
assert result[1].o.value == 'object2'
|
||||
|
|
@ -541,16 +601,20 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
"""Test cases for multi-table performance optimizations in query service"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_get_po_query_optimization(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_get_po_query_optimization(self, mock_kg_class):
|
||||
"""Test that get_po queries use optimized table (no ALLOW FILTERING)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_po.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
|
@ -560,8 +624,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=50
|
||||
)
|
||||
|
||||
|
|
@ -569,7 +633,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# Verify get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
'test_collection', 'test_predicate', 'test_object', limit=50
|
||||
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
|
|
@ -578,16 +642,20 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_get_os_query_optimization(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_get_os_query_optimization(self, mock_kg_class):
|
||||
"""Test that get_os queries use optimized table (no ALLOW FILTERING)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_os.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
|
@ -596,9 +664,9 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=25
|
||||
)
|
||||
|
||||
|
|
@ -606,7 +674,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# Verify get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.get_os.assert_called_once_with(
|
||||
'test_collection', 'test_object', 'test_subject', limit=25
|
||||
'test_collection', 'test_object', 'test_subject', g=None, limit=25
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
|
|
@ -615,13 +683,13 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_all_query_patterns_use_correct_tables(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_all_query_patterns_use_correct_tables(self, mock_kg_class):
|
||||
"""Test that all query patterns route to their optimal tables"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
# Mock empty results for all queries
|
||||
mock_tg_instance.get_all.return_value = []
|
||||
|
|
@ -655,9 +723,9 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value=s, is_uri=False) if s else None,
|
||||
p=Value(value=p, is_uri=False) if p else None,
|
||||
o=Value(value=o, is_uri=False) if o else None,
|
||||
s=Term(type=LITERAL, value=s) if s else None,
|
||||
p=Term(type=LITERAL, value=p) if p else None,
|
||||
o=Term(type=LITERAL, value=o) if o else None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
|
@ -687,19 +755,23 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
# Mode is determined in KnowledgeGraph initialization
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_performance_critical_po_query_no_filtering(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_performance_critical_po_query_no_filtering(self, mock_kg_class):
|
||||
"""Test the performance-critical PO query that eliminates ALLOW FILTERING"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
# Mock multiple subjects for the same predicate-object pair
|
||||
mock_results = []
|
||||
for i in range(5):
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = f'subject_{i}'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_results.append(mock_result)
|
||||
|
||||
mock_tg_instance.get_po.return_value = mock_results
|
||||
|
|
@ -711,8 +783,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
user='large_dataset_user',
|
||||
collection='massive_collection',
|
||||
s=None,
|
||||
p=Value(value='http://www.w3.org/1999/02/22-rdf-syntax-ns#type', is_uri=True),
|
||||
o=Value(value='http://example.com/Person', is_uri=True),
|
||||
p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
|
||||
o=Term(type=IRI, iri='http://example.com/Person'),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -723,14 +795,15 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
'massive_collection',
|
||||
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
|
||||
'http://example.com/Person',
|
||||
g=None,
|
||||
limit=1000
|
||||
)
|
||||
|
||||
# Verify all results were returned
|
||||
assert len(result) == 5
|
||||
for i, triple in enumerate(result):
|
||||
assert triple.s.value == f'subject_{i}'
|
||||
assert triple.p.value == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
|
||||
assert triple.p.is_uri is True
|
||||
assert triple.o.value == 'http://example.com/Person'
|
||||
assert triple.o.is_uri is True
|
||||
assert triple.s.value == f'subject_{i}' # Mock returns literal values
|
||||
assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
|
||||
assert triple.p.type == IRI
|
||||
assert triple.o.iri == 'http://example.com/Person' # URIs use .iri
|
||||
assert triple.o.type == IRI
|
||||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.falkordb.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
|
||||
|
||||
|
||||
class TestFalkorDBQueryProcessor:
|
||||
|
|
@ -25,50 +25,50 @@ class TestFalkorDBQueryProcessor:
|
|||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
def test_processor_initialization_with_defaults(self, mock_falkordb):
|
||||
|
|
@ -125,9 +125,9 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -138,8 +138,8 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify result contains the queried triple (appears twice - once from each query)
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
|
|
@ -166,8 +166,8 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -179,13 +179,13 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different objects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal result"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri_result"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri_result"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -211,9 +211,9 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -224,12 +224,12 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different predicates
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
|
|
@ -256,7 +256,7 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
|
|
@ -269,13 +269,13 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different predicate-object pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].o.value == "http://example.com/uri2"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.iri == "http://example.com/uri2"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -302,8 +302,8 @@ class TestFalkorDBQueryProcessor:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -314,12 +314,12 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different subjects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
|
|
@ -347,7 +347,7 @@ class TestFalkorDBQueryProcessor:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -359,13 +359,13 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different subject-object pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri2"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri2"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -393,7 +393,7 @@ class TestFalkorDBQueryProcessor:
|
|||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -404,12 +404,12 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different subject-predicate pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
|
|
@ -449,13 +449,13 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different triples
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/s1"
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].s.iri == "http://example.com/s1"
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/s2"
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].o.value == "http://example.com/o2"
|
||||
assert result[1].s.iri == "http://example.com/s2"
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].o.iri == "http://example.com/o2"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -476,7 +476,7 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.memgraph.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMemgraphQueryProcessor:
|
||||
|
|
@ -25,50 +25,50 @@ class TestMemgraphQueryProcessor:
|
|||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
def test_processor_initialization_with_defaults(self, mock_graph_db):
|
||||
|
|
@ -124,9 +124,9 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -137,8 +137,8 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify result contains the queried triple (appears twice - once from each query)
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
|
|
@ -166,8 +166,8 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -179,13 +179,13 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different objects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal result"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri_result"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri_result"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -212,9 +212,9 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -225,12 +225,12 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different predicates
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
|
|
@ -258,7 +258,7 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
|
|
@ -271,13 +271,13 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different predicate-object pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].o.value == "http://example.com/uri2"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.iri == "http://example.com/uri2"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -305,8 +305,8 @@ class TestMemgraphQueryProcessor:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -317,12 +317,12 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different subjects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
|
|
@ -351,7 +351,7 @@ class TestMemgraphQueryProcessor:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -363,13 +363,13 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different subject-object pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri2"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri2"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -398,7 +398,7 @@ class TestMemgraphQueryProcessor:
|
|||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -409,12 +409,12 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different subject-predicate pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
|
|
@ -455,13 +455,13 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different triples
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/s1"
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].s.iri == "http://example.com/s1"
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/s2"
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].o.value == "http://example.com/o2"
|
||||
assert result[1].s.iri == "http://example.com/s2"
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].o.iri == "http://example.com/o2"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -480,7 +480,7 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.neo4j.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
|
||||
|
||||
|
||||
class TestNeo4jQueryProcessor:
|
||||
|
|
@ -25,50 +25,50 @@ class TestNeo4jQueryProcessor:
|
|||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
def test_processor_initialization_with_defaults(self, mock_graph_db):
|
||||
|
|
@ -124,9 +124,9 @@ class TestNeo4jQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -137,8 +137,8 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Verify result contains the queried triple (appears twice - once from each query)
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
|
|
@ -166,8 +166,8 @@ class TestNeo4jQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -179,13 +179,13 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Verify results contain different objects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal result"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri_result"
|
||||
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri_result"
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -225,13 +225,13 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Verify results contain different triples
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/s1"
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].s.iri == "http://example.com/s1"
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/s2"
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].o.value == "http://example.com/o2"
|
||||
|
||||
assert result[1].s.iri == "http://example.com/s2"
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].o.iri == "http://example.com/o2"
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -250,12 +250,12 @@ class TestNeo4jQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Database connection failed"):
|
||||
await processor.query_triples(query)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
from trustgraph.schema import (
|
||||
StructuredQueryRequest, StructuredQueryResponse,
|
||||
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
|
||||
ObjectsQueryRequest, ObjectsQueryResponse,
|
||||
RowsQueryRequest, RowsQueryResponse,
|
||||
Error, GraphQLError
|
||||
)
|
||||
from trustgraph.retrieval.structured_query.service import Processor
|
||||
|
|
@ -68,7 +68,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
|
||||
# Mock objects query service response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
|
||||
errors=None,
|
||||
|
|
@ -86,7 +86,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -108,7 +108,7 @@ class TestStructuredQueryProcessor:
|
|||
# Verify objects query service was called correctly
|
||||
mock_objects_client.request.assert_called_once()
|
||||
objects_call_args = mock_objects_client.request.call_args[0][0]
|
||||
assert isinstance(objects_call_args, ObjectsQueryRequest)
|
||||
assert isinstance(objects_call_args, RowsQueryRequest)
|
||||
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
|
||||
assert objects_call_args.variables == {"state": "NY"}
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
|
|
@ -224,7 +224,7 @@ class TestStructuredQueryProcessor:
|
|||
assert response.error is not None
|
||||
assert "empty GraphQL query" in response.error.message
|
||||
|
||||
async def test_objects_query_service_error(self, processor):
|
||||
async def test_rows_query_service_error(self, processor):
|
||||
"""Test handling of objects query service errors"""
|
||||
# Arrange
|
||||
request = StructuredQueryRequest(
|
||||
|
|
@ -250,7 +250,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
|
||||
# Mock objects query service error
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=Error(type="graphql-execution-error", message="Table 'customers' not found"),
|
||||
data=None,
|
||||
errors=None,
|
||||
|
|
@ -267,7 +267,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -284,7 +284,7 @@ class TestStructuredQueryProcessor:
|
|||
response = response_call[0][0]
|
||||
|
||||
assert response.error is not None
|
||||
assert "Objects query service error" in response.error.message
|
||||
assert "Rows query service error" in response.error.message
|
||||
assert "Table 'customers' not found" in response.error.message
|
||||
|
||||
async def test_graphql_errors_handling(self, processor):
|
||||
|
|
@ -321,7 +321,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
]
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=None,
|
||||
errors=graphql_errors,
|
||||
|
|
@ -338,7 +338,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -400,7 +400,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
|
||||
# Mock objects response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}',
|
||||
errors=None
|
||||
|
|
@ -416,7 +416,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -464,7 +464,7 @@ class TestStructuredQueryProcessor:
|
|||
confidence=0.9
|
||||
)
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=None, # Null data
|
||||
errors=None,
|
||||
|
|
@ -481,7 +481,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
|
||||
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
|
||||
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
|
||||
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
|
||||
from trustgraph.storage.knowledge.store import Processor as KgStore
|
||||
|
||||
|
|
@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration:
|
|||
assert processor.cassandra_password is None
|
||||
|
||||
|
||||
class TestObjectsWriterConfiguration:
|
||||
class TestRowsWriterConfiguration:
|
||||
"""Test Cassandra configuration in objects writer processor."""
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
def test_environment_variable_configuration(self, mock_cluster):
|
||||
"""Test processor picks up configuration from environment variables."""
|
||||
env_vars = {
|
||||
|
|
@ -97,13 +97,13 @@ class TestObjectsWriterConfiguration:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
|
||||
assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2']
|
||||
assert processor.cassandra_username == 'obj-env-user'
|
||||
assert processor.cassandra_password == 'obj-env-pass'
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
def test_cassandra_connection_with_hosts_list(self, mock_cluster):
|
||||
"""Test that Cassandra connection uses hosts list correctly."""
|
||||
env_vars = {
|
||||
|
|
@ -118,7 +118,7 @@ class TestObjectsWriterConfiguration:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify cluster was called with hosts list
|
||||
|
|
@ -129,8 +129,8 @@ class TestObjectsWriterConfiguration:
|
|||
assert 'contact_points' in call_args.kwargs
|
||||
assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3']
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_authentication_configuration(self, mock_auth_provider, mock_cluster):
|
||||
"""Test authentication is configured when credentials are provided."""
|
||||
env_vars = {
|
||||
|
|
@ -145,7 +145,7 @@ class TestObjectsWriterConfiguration:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify auth provider was created with correct credentials
|
||||
|
|
@ -302,10 +302,10 @@ class TestCommandLineArgumentHandling:
|
|||
def test_objects_writer_add_args(self):
|
||||
"""Test that objects writer adds standard Cassandra arguments."""
|
||||
import argparse
|
||||
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
|
||||
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
ObjectsWriter.add_args(parser)
|
||||
RowsWriter.add_args(parser)
|
||||
|
||||
# Parse empty args to check that arguments exist
|
||||
args = parser.parse_args([])
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.graph_embeddings.milvus.write import Processor
|
||||
from trustgraph.schema import Value, EntityEmbeddings
|
||||
from trustgraph.schema import Term, EntityEmbeddings, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||
|
|
@ -22,11 +22,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
|
||||
# Create test entities with embeddings
|
||||
entity1 = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity1', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/entity1'),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
entity2 = EntityEmbeddings(
|
||||
entity=Value(value='literal entity', is_uri=False),
|
||||
entity=Term(type=LITERAL, value='literal entity'),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.entities = [entity1, entity2]
|
||||
|
|
@ -84,7 +84,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
|
@ -136,7 +136,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='', is_uri=False),
|
||||
entity=Term(type=LITERAL, value=''),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
|
@ -155,7 +155,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
entity=Term(type=LITERAL, value=None),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
|
@ -174,15 +174,15 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/valid', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/valid'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
empty_entity = EntityEmbeddings(
|
||||
entity=Value(value='', is_uri=False),
|
||||
entity=Term(type=LITERAL, value=''),
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
none_entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
entity=Term(type=LITERAL, value=None),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.entities = [valid_entity, empty_entity, none_entity]
|
||||
|
|
@ -217,7 +217,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||
vectors=[]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
|
@ -236,7 +236,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
|
|
@ -269,11 +269,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
uri_entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/uri_entity', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/uri_entity'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
literal_entity = EntityEmbeddings(
|
||||
entity=Value(value='literal entity text', is_uri=False),
|
||||
entity=Term(type=LITERAL, value='literal entity text'),
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.entities = [uri_entity, literal_entity]
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||
|
||||
# Import the service under test
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||
|
|
@ -67,7 +68,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'test_entity'
|
||||
mock_entity.entity.type = IRI
|
||||
mock_entity.entity.iri = 'test_entity'
|
||||
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
|
@ -120,11 +122,13 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.value = 'entity_one'
|
||||
mock_entity1.entity.type = IRI
|
||||
mock_entity1.entity.iri = 'entity_one'
|
||||
mock_entity1.vectors = [[0.1, 0.2]]
|
||||
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.value = 'entity_two'
|
||||
mock_entity2.entity.type = IRI
|
||||
mock_entity2.entity.iri = 'entity_two'
|
||||
mock_entity2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
|
|
@ -179,7 +183,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'multi_vector_entity'
|
||||
mock_entity.entity.type = IRI
|
||||
mock_entity.entity.iri = 'multi_vector_entity'
|
||||
mock_entity.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
|
|
@ -231,11 +236,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity_empty = MagicMock()
|
||||
mock_entity_empty.entity.type = LITERAL
|
||||
mock_entity_empty.entity.value = "" # Empty string
|
||||
mock_entity_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
|
||||
mock_entity_none = MagicMock()
|
||||
mock_entity_none.entity.value = None # None value
|
||||
mock_entity_none.entity = None # None entity
|
||||
mock_entity_none.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch, call
|
|||
|
||||
from trustgraph.storage.triples.neo4j.write import Processor as StorageProcessor
|
||||
from trustgraph.query.triples.neo4j.service import Processor as QueryProcessor
|
||||
from trustgraph.schema import Triples, Triple, Value, Metadata
|
||||
from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
|
||||
from trustgraph.schema import TriplesQueryRequest
|
||||
|
||||
|
||||
|
|
@ -60,9 +60,9 @@ class TestNeo4jUserCollectionIsolation:
|
|||
)
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal_value", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal_value")
|
||||
)
|
||||
|
||||
message = Triples(
|
||||
|
|
@ -128,9 +128,9 @@ class TestNeo4jUserCollectionIsolation:
|
|||
metadata = Metadata(id="test-id")
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="http://example.com/object", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=IRI, iri="http://example.com/object")
|
||||
)
|
||||
|
||||
message = Triples(
|
||||
|
|
@ -170,8 +170,8 @@ class TestNeo4jUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None
|
||||
)
|
||||
|
||||
|
|
@ -254,9 +254,9 @@ class TestNeo4jUserCollectionIsolation:
|
|||
metadata=Metadata(user="user1", collection="coll1"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.com/user1/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="user1_data", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.com/user1/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="user1_data")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -265,9 +265,9 @@ class TestNeo4jUserCollectionIsolation:
|
|||
metadata=Metadata(user="user2", collection="coll2"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.com/user2/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="user2_data", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.com/user2/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="user2_data")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -429,9 +429,9 @@ class TestNeo4jUserCollectionRegression:
|
|||
metadata=Metadata(user="user1", collection="coll1"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value=shared_uri, is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="user1_value", is_uri=False)
|
||||
s=Term(type=IRI, iri=shared_uri),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="user1_value")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -440,9 +440,9 @@ class TestNeo4jUserCollectionRegression:
|
|||
metadata=Metadata(user="user2", collection="coll2"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value=shared_uri, is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="user2_value", is_uri=False)
|
||||
s=Term(type=IRI, iri=shared_uri),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="user2_value")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,533 +0,0 @@
|
|||
"""
|
||||
Unit tests for Cassandra Object Storage Processor
|
||||
|
||||
Tests the business logic of the object storage processor including:
|
||||
- Schema configuration handling
|
||||
- Type conversions
|
||||
- Name sanitization
|
||||
- Table structure generation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
from trustgraph.storage.objects.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
class TestObjectsCassandraStorageLogic:
    """Test business logic without FlowProcessor dependencies"""

    def test_sanitize_name(self):
        """Test name sanitization for Cassandra compatibility"""
        proc = MagicMock()
        proc.sanitize_name = Processor.sanitize_name.__get__(proc, Processor)

        # (raw, sanitized) pairs covering the supported name patterns
        # (back to original logic).
        for raw, clean in [
            ("simple_name", "simple_name"),
            ("Name-With-Dashes", "name_with_dashes"),
            ("name.with.dots", "name_with_dots"),
            ("123_starts_with_number", "o_123_starts_with_number"),
            ("name with spaces", "name_with_spaces"),
            ("special!@#$%^chars", "special______chars"),
        ]:
            assert proc.sanitize_name(raw) == clean

    def test_get_cassandra_type(self):
        """Test field type conversion to Cassandra types"""
        proc = MagicMock()
        proc.get_cassandra_type = Processor.get_cassandra_type.__get__(proc, Processor)

        # Simple one-to-one mappings.
        assert proc.get_cassandra_type("string") == "text"
        assert proc.get_cassandra_type("boolean") == "boolean"
        assert proc.get_cassandra_type("timestamp") == "timestamp"
        assert proc.get_cassandra_type("uuid") == "uuid"

        # Numeric types widen according to the size hint.
        assert proc.get_cassandra_type("integer", size=2) == "int"
        assert proc.get_cassandra_type("integer", size=8) == "bigint"
        assert proc.get_cassandra_type("float", size=2) == "float"
        assert proc.get_cassandra_type("float", size=8) == "double"

        # Anything unrecognised falls back to text.
        assert proc.get_cassandra_type("unknown_type") == "text"

    def test_convert_value(self):
        """Test value conversion for different field types"""
        proc = MagicMock()
        proc.convert_value = Processor.convert_value.__get__(proc, Processor)

        # Integers: strings and floats are coerced, None passes through.
        assert proc.convert_value("123", "integer") == 123
        assert proc.convert_value(123.5, "integer") == 123
        assert proc.convert_value(None, "integer") is None

        # Floats.
        assert proc.convert_value("123.45", "float") == 123.45
        assert proc.convert_value(123, "float") == 123.0

        # Booleans accept several truthy/falsy spellings.
        for text, flag in [
            ("true", True), ("false", False),
            ("1", True), ("0", False),
            ("yes", True), ("no", False),
        ]:
            assert proc.convert_value(text, "boolean") is flag

        # Strings: anything is stringified.
        assert proc.convert_value(123, "string") == "123"
        assert proc.convert_value(True, "string") == "True"

    def test_table_creation_cql_generation(self):
        """Test CQL generation for table creation"""
        proc = MagicMock()
        proc.schemas = {}
        proc.known_keyspaces = set()
        proc.known_tables = {}
        proc.session = MagicMock()
        # Rebind the real implementations onto the mock instance.
        proc.sanitize_name = Processor.sanitize_name.__get__(proc, Processor)
        proc.sanitize_table = Processor.sanitize_table.__get__(proc, Processor)
        proc.get_cassandra_type = Processor.get_cassandra_type.__get__(proc, Processor)

        def fake_ensure_keyspace(keyspace):
            proc.known_keyspaces.add(keyspace)
            proc.known_tables[keyspace] = set()

        proc.ensure_keyspace = fake_ensure_keyspace
        proc.ensure_table = Processor.ensure_table.__get__(proc, Processor)

        schema = RowSchema(
            name="customer_records",
            description="Test customer schema",
            fields=[
                Field(
                    name="customer_id", type="string", size=50,
                    primary=True, required=True, indexed=False,
                ),
                Field(
                    name="email", type="string", size=100,
                    required=True, indexed=True,
                ),
                Field(
                    name="age", type="integer", size=4,
                    required=False, indexed=False,
                ),
            ],
        )

        proc.ensure_table("test_user", "customer_records", schema)

        # The keyspace must have been ensured first.
        assert "test_user" in proc.known_keyspaces

        # The first execute() call is the CREATE TABLE statement
        # (keyspace uses sanitize_name, table uses sanitize_table).
        create_cql = proc.session.execute.call_args_list[0][0][0]
        for fragment in (
            "CREATE TABLE IF NOT EXISTS test_user.o_customer_records",
            "collection text",
            "customer_id text",
            "email text",
            "age int",
            "PRIMARY KEY ((collection, customer_id))",
        ):
            assert fragment in create_cql

    def test_table_creation_without_primary_key(self):
        """Test table creation when no primary key is defined"""
        proc = MagicMock()
        proc.schemas = {}
        proc.known_keyspaces = set()
        proc.known_tables = {}
        proc.session = MagicMock()
        proc.sanitize_name = Processor.sanitize_name.__get__(proc, Processor)
        proc.sanitize_table = Processor.sanitize_table.__get__(proc, Processor)
        proc.get_cassandra_type = Processor.get_cassandra_type.__get__(proc, Processor)

        def fake_ensure_keyspace(keyspace):
            proc.known_keyspaces.add(keyspace)
            proc.known_tables[keyspace] = set()

        proc.ensure_keyspace = fake_ensure_keyspace
        proc.ensure_table = Processor.ensure_table.__get__(proc, Processor)

        # No field is marked primary, so a synthetic key is expected.
        schema = RowSchema(
            name="events",
            description="Event log",
            fields=[
                Field(name="event_type", type="string", size=50),
                Field(name="timestamp", type="timestamp", size=0),
            ],
        )

        proc.ensure_table("test_user", "events", schema)

        # A synthetic_id column is injected (field names get no o_ prefix).
        executed_cql = proc.session.execute.call_args[0][0]
        assert "synthetic_id uuid" in executed_cql
        assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql

    @pytest.mark.asyncio
    async def test_schema_config_parsing(self):
        """Test parsing of schema configurations"""
        proc = MagicMock()
        proc.schemas = {}
        proc.config_key = "schema"
        proc.on_schema_config = Processor.on_schema_config.__get__(proc, Processor)

        schema_def = {
            "name": "customer_records",
            "description": "Customer data",
            "fields": [
                {
                    "name": "id",
                    "type": "string",
                    "primary_key": True,
                    "required": True,
                },
                {
                    "name": "name",
                    "type": "string",
                    "required": True,
                },
                {
                    "name": "balance",
                    "type": "float",
                    "size": 8,
                },
            ],
        }
        config = {"schema": {"customer_records": json.dumps(schema_def)}}

        await proc.on_schema_config(config, version=1)

        # Schema was parsed and registered.
        assert "customer_records" in proc.schemas
        schema = proc.schemas["customer_records"]
        assert schema.name == "customer_records"
        assert len(schema.fields) == 3

        id_field = schema.fields[0]
        assert id_field.name == "id"
        assert id_field.type == "string"
        assert id_field.primary is True
        # Note: Field.required always returns False due to Pulsar schema
        # limitations; the actual required value is tracked during parsing.

    @pytest.mark.asyncio
    async def test_object_processing_logic(self):
        """Test the logic for processing ExtractedObject"""
        proc = MagicMock()
        proc.schemas = {
            "test_schema": RowSchema(
                name="test_schema",
                description="Test",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="value", type="integer", size=4),
                ],
            )
        }
        proc.ensure_table = MagicMock()
        proc.sanitize_name = Processor.sanitize_name.__get__(proc, Processor)
        proc.sanitize_table = Processor.sanitize_table.__get__(proc, Processor)
        proc.convert_value = Processor.convert_value.__get__(proc, Processor)
        proc.session = MagicMock()
        proc.on_object = Processor.on_object.__get__(proc, Processor)
        # Pre-populate so no validation query is issued.
        proc.known_keyspaces = {"test_user"}
        proc.known_tables = {"test_user": set()}

        test_obj = ExtractedObject(
            metadata=Metadata(
                id="test-001",
                user="test_user",
                collection="test_collection",
                metadata=[],
            ),
            schema_name="test_schema",
            values=[{"id": "123", "value": "456"}],
            confidence=0.9,
            source_span="test source",
        )

        msg = MagicMock()
        msg.value.return_value = test_obj

        await proc.on_object(msg, None, None)

        # Table must be ensured before inserting.
        proc.ensure_table.assert_called_once_with(
            "test_user", "test_schema", proc.schemas["test_schema"])

        # One insert, into the o_-prefixed table, collection column first.
        proc.session.execute.assert_called_once()
        insert_cql = proc.session.execute.call_args[0][0]
        values = proc.session.execute.call_args[0][1]

        assert "INSERT INTO test_user.o_test_schema" in insert_cql
        assert "collection" in insert_cql
        assert values[0] == "test_collection"  # collection value
        assert values[1] == "123"              # id value (from values[0])
        assert values[2] == 456                # converted integer value

    def test_secondary_index_creation(self):
        """Test that secondary indexes are created for indexed fields"""
        proc = MagicMock()
        proc.schemas = {}
        proc.known_keyspaces = {"test_user"}      # skip validation query
        proc.known_tables = {"test_user": set()}
        proc.session = MagicMock()
        proc.sanitize_name = Processor.sanitize_name.__get__(proc, Processor)
        proc.sanitize_table = Processor.sanitize_table.__get__(proc, Processor)
        proc.get_cassandra_type = Processor.get_cassandra_type.__get__(proc, Processor)

        def fake_ensure_keyspace(keyspace):
            proc.known_keyspaces.add(keyspace)
            if keyspace not in proc.known_tables:
                proc.known_tables[keyspace] = set()

        proc.ensure_keyspace = fake_ensure_keyspace
        proc.ensure_table = Processor.ensure_table.__get__(proc, Processor)

        schema = RowSchema(
            name="products",
            description="Product catalog",
            fields=[
                Field(name="product_id", type="string", size=50, primary=True),
                Field(name="category", type="string", size=30, indexed=True),
                Field(name="price", type="float", size=8, indexed=True),
            ],
        )

        proc.ensure_table("test_user", "products", schema)

        # One CREATE TABLE plus one CREATE INDEX per indexed field.
        assert proc.session.execute.call_count == 3

        # Index names carry the o_ table prefix; field names do not.
        executed = [c[0][0] for c in proc.session.execute.call_args_list]
        index_cql = [cql for cql in executed if "CREATE INDEX" in cql]
        assert len(index_cql) == 2
        assert any("o_products_category_idx" in cql for cql in index_cql)
        assert any("o_products_price_idx" in cql for cql in index_cql)
||||
class TestObjectsCassandraStorageBatchLogic:
    """Test batch processing logic in Cassandra storage"""

    @pytest.mark.asyncio
    async def test_batch_object_processing_logic(self):
        """Test processing of batch ExtractedObjects.

        One insert must be issued per batch item, each carrying the
        collection value plus the converted field values for that item.
        """
        processor = MagicMock()
        processor.schemas = {
            "batch_schema": RowSchema(
                name="batch_schema",
                description="Test batch schema",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="name", type="string", size=100),
                    Field(name="value", type="integer", size=4),
                ],
            )
        }
        # Pre-populate to skip validation queries (consistent with the
        # other tests in this class).
        processor.known_keyspaces = {"test_user"}
        processor.known_tables = {"test_user": set()}
        processor.ensure_table = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
        processor.session = MagicMock()
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        # Create batch object with multiple values
        batch_obj = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_collection",
                metadata=[],
            ),
            schema_name="batch_schema",
            values=[
                {"id": "001", "name": "First", "value": "100"},
                {"id": "002", "name": "Second", "value": "200"},
                {"id": "003", "name": "Third", "value": "300"},
            ],
            confidence=0.95,
            source_span="batch source",
        )

        msg = MagicMock()
        msg.value.return_value = batch_obj

        # Process batch object
        await processor.on_object(msg, None, None)

        # Verify table was ensured once
        processor.ensure_table.assert_called_once_with(
            "test_user", "batch_schema", processor.schemas["batch_schema"])

        # Verify 3 separate insert calls (one per batch item)
        assert processor.session.execute.call_count == 3

        # Expected (id, name, value) per batch item.
        # BUGFIX: the previous assertion was written as
        #   assert values[2] == f"First" if i == 0 else f"Second" ...
        # which parses as a conditional expression AROUND the whole assert
        # operand, so for i > 0 the assert tested a non-empty string and
        # always passed — the name column was never actually checked.
        expected = [
            ("001", "First", 100),
            ("002", "Second", 200),
            ("003", "Third", 300),
        ]

        for i, execute_call in enumerate(processor.session.execute.call_args_list):
            insert_cql = execute_call[0][0]
            values = execute_call[0][1]

            assert "INSERT INTO test_user.o_batch_schema" in insert_cql
            assert "collection" in insert_cql

            exp_id, exp_name, exp_value = expected[i]
            assert values[0] == "batch_collection"  # collection
            assert values[1] == exp_id              # id from batch item i
            assert values[2] == exp_name            # name from batch item i
            assert values[3] == exp_value           # converted integer value

    @pytest.mark.asyncio
    async def test_empty_batch_processing_logic(self):
        """Test processing of empty batch ExtractedObjects.

        An empty values list must ensure the table but issue no inserts.
        """
        processor = MagicMock()
        processor.schemas = {
            "empty_schema": RowSchema(
                name="empty_schema",
                fields=[Field(name="id", type="string", size=50, primary=True)],
            )
        }
        processor.ensure_table = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
        processor.session = MagicMock()
        processor.on_object = Processor.on_object.__get__(processor, Processor)
        # Pre-populate to skip validation query
        processor.known_keyspaces = {"test_user"}
        processor.known_tables = {"test_user": set()}

        empty_batch_obj = ExtractedObject(
            metadata=Metadata(
                id="empty-001",
                user="test_user",
                collection="empty_collection",
                metadata=[],
            ),
            schema_name="empty_schema",
            values=[],  # Empty batch
            confidence=1.0,
            source_span="empty source",
        )

        msg = MagicMock()
        msg.value.return_value = empty_batch_obj

        await processor.on_object(msg, None, None)

        # Table is still ensured, but nothing is written.
        processor.ensure_table.assert_called_once()
        processor.session.execute.assert_not_called()

    @pytest.mark.asyncio
    async def test_single_item_batch_processing_logic(self):
        """Test processing of single-item batch (backward compatibility).

        A one-element values list must behave exactly like the legacy
        single-object path: one ensure_table, one insert.
        """
        processor = MagicMock()
        processor.schemas = {
            "single_schema": RowSchema(
                name="single_schema",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="data", type="string", size=100),
                ],
            )
        }
        processor.ensure_table = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
        processor.session = MagicMock()
        processor.on_object = Processor.on_object.__get__(processor, Processor)
        # Pre-populate to skip validation query
        processor.known_keyspaces = {"test_user"}
        processor.known_tables = {"test_user": set()}

        single_batch_obj = ExtractedObject(
            metadata=Metadata(
                id="single-001",
                user="test_user",
                collection="single_collection",
                metadata=[],
            ),
            schema_name="single_schema",
            values=[{"id": "single-1", "data": "single data"}],  # one item
            confidence=0.8,
            source_span="single source",
        )

        msg = MagicMock()
        msg.value.return_value = single_batch_obj

        await processor.on_object(msg, None, None)

        processor.ensure_table.assert_called_once()
        processor.session.execute.assert_called_once()

        insert_cql = processor.session.execute.call_args[0][0]
        values = processor.session.execute.call_args[0][1]

        assert "INSERT INTO test_user.o_single_schema" in insert_cql
        assert values[0] == "single_collection"  # collection
        assert values[1] == "single-1"           # id value
        assert values[2] == "single data"        # data value

    def test_batch_value_conversion_logic(self):
        """Test value conversion works correctly for batch items."""
        processor = MagicMock()
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)

        # (input, field_type, expected) covering the conversions a batch
        # of heterogeneous items would exercise.
        test_cases = [
            # Integer conversions for batch items
            ("123", "integer", 123),
            ("456", "integer", 456),
            ("789", "integer", 789),
            # Float conversions for batch items
            ("12.5", "float", 12.5),
            ("34.7", "float", 34.7),
            # Boolean conversions for batch items
            ("true", "boolean", True),
            ("false", "boolean", False),
            ("1", "boolean", True),
            ("0", "boolean", False),
            # String conversions for batch items
            (123, "string", "123"),
            (45.6, "string", "45.6"),
        ]

        for input_val, field_type, expected_output in test_cases:
            result = processor.convert_value(input_val, field_type)
            assert result == expected_output, (
                f"Failed for {input_val} -> {field_type}: "
                f"got {result}, expected {expected_output}"
            )
||||
435
tests/unit/test_storage/test_row_embeddings_qdrant_storage.py
Normal file
435
tests/unit/test_storage/test_row_embeddings_qdrant_storage.py
Normal file
|
|
@ -0,0 +1,435 @@
|
|||
"""
|
||||
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
|
||||
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
|
||||
class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant row embeddings storage functionality"""
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_processor_initialization_basic(self, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
mock_qdrant_client.assert_called_once_with(
|
||||
url='http://localhost:6333', api_key='test-api-key'
|
||||
)
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
mock_qdrant_client.assert_called_once_with(
|
||||
url='http://localhost:6333', api_key=None
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_sanitize_name(self, mock_qdrant_client):
|
||||
"""Test name sanitization for Qdrant collections"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Test basic sanitization
|
||||
assert processor.sanitize_name("simple") == "simple"
|
||||
assert processor.sanitize_name("with-dash") == "with_dash"
|
||||
assert processor.sanitize_name("with.dot") == "with_dot"
|
||||
assert processor.sanitize_name("UPPERCASE") == "uppercase"
|
||||
|
||||
# Test numeric prefix handling
|
||||
assert processor.sanitize_name("123start") == "r_123start"
|
||||
assert processor.sanitize_name("_underscore") == "r__underscore"
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_get_collection_name(self, mock_qdrant_client):
|
||||
"""Test Qdrant collection name generation"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
collection_name = processor.get_collection_name(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="customer_data",
|
||||
dimension=384
|
||||
)
|
||||
|
||||
assert collection_name == "rows_test_user_test_collection_customer_data_384"
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
|
||||
"""Test that ensure_collection creates a new collection when needed"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
processor.ensure_collection("test_collection", 384)
|
||||
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify the collection is cached
|
||||
assert "test_collection" in processor.created_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
|
||||
"""Test that ensure_collection skips creation when collection exists"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
processor.ensure_collection("existing_collection", 384)
|
||||
|
||||
mock_qdrant_instance.collection_exists.assert_called_once()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_ensure_collection_uses_cache(self, mock_qdrant_client):
|
||||
"""Test that ensure_collection uses cache for previously created collections"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add("cached_collection")
|
||||
|
||||
processor.ensure_collection("cached_collection", 384)
|
||||
|
||||
# Should not check or create - just return
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
|
||||
async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client):
|
||||
"""Test processing basic row embeddings message"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding, Metadata
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = 'test-uuid-123'
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
|
||||
# Create embeddings message
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
embedding = RowIndexEmbedding(
|
||||
index_name='customer_id',
|
||||
index_value=['CUST001'],
|
||||
text='CUST001',
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
|
||||
embeddings_msg = RowEmbeddings(
|
||||
metadata=metadata,
|
||||
schema_name='customers',
|
||||
embeddings=[embedding]
|
||||
)
|
||||
|
||||
# Mock message wrapper
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3'
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload['index_name'] == 'customer_id'
|
||||
assert point.payload['index_value'] == ['CUST001']
|
||||
assert point.payload['text'] == 'CUST001'
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
|
||||
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
|
||||
"""Test processing embeddings with multiple vectors"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = 'test-uuid'
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
# Embedding with multiple vectors
|
||||
embedding = RowIndexEmbedding(
|
||||
index_name='name',
|
||||
index_value=['John Doe'],
|
||||
text='John Doe',
|
||||
vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
||||
)
|
||||
|
||||
embeddings_msg = RowEmbeddings(
|
||||
metadata=metadata,
|
||||
schema_name='people',
|
||||
embeddings=[embedding]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
|
||||
# Should be called 3 times (once per vector)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
|
||||
"""Test that embeddings with no vectors are skipped"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
# Embedding with no vectors
|
||||
embedding = RowIndexEmbedding(
|
||||
index_name='id',
|
||||
index_value=['123'],
|
||||
text='123',
|
||||
vectors=[] # Empty vectors
|
||||
)
|
||||
|
||||
embeddings_msg = RowEmbeddings(
|
||||
metadata=metadata,
|
||||
schema_name='items',
|
||||
embeddings=[embedding]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
|
||||
# Should not call upsert for empty vectors
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client):
|
||||
"""Test that messages for unknown collections are dropped"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
# No collections registered
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'unknown_user'
|
||||
metadata.collection = 'unknown_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
embedding = RowIndexEmbedding(
|
||||
index_name='id',
|
||||
index_value=['123'],
|
||||
text='123',
|
||||
vectors=[[0.1, 0.2]]
|
||||
)
|
||||
|
||||
embeddings_msg = RowEmbeddings(
|
||||
metadata=metadata,
|
||||
schema_name='items',
|
||||
embeddings=[embedding]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
|
||||
# Should not call upsert for unknown collection
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_delete_collection(self, mock_qdrant_client):
|
||||
"""Test deleting all collections for a user/collection"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
|
||||
# Mock collections list
|
||||
mock_coll1 = MagicMock()
|
||||
mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
|
||||
mock_coll2 = MagicMock()
|
||||
mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
|
||||
mock_coll3 = MagicMock()
|
||||
mock_coll3.name = 'rows_other_user_other_collection_schema_384'
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
|
||||
mock_qdrant_instance.get_collections.return_value = mock_collections
|
||||
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add('rows_test_user_test_collection_schema1_384')
|
||||
|
||||
await processor.delete_collection('test_user', 'test_collection')
|
||||
|
||||
# Should delete only the matching collections
|
||||
assert mock_qdrant_instance.delete_collection.call_count == 2
|
||||
|
||||
# Verify the cached collection was removed
|
||||
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_delete_collection_schema(self, mock_qdrant_client):
|
||||
"""Test deleting collections for a specific schema"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
|
||||
mock_qdrant_instance = MagicMock()
|
||||
|
||||
mock_coll1 = MagicMock()
|
||||
mock_coll1.name = 'rows_test_user_test_collection_customers_384'
|
||||
mock_coll2 = MagicMock()
|
||||
mock_coll2.name = 'rows_test_user_test_collection_orders_384'
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll1, mock_coll2]
|
||||
mock_qdrant_instance.get_collections.return_value = mock_collections
|
||||
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
await processor.delete_collection_schema(
|
||||
'test_user', 'test_collection', 'customers'
|
||||
)
|
||||
|
||||
# Should only delete the customers schema collection
|
||||
mock_qdrant_instance.delete_collection.assert_called_once()
|
||||
call_args = mock_qdrant_instance.delete_collection.call_args[0]
|
||||
assert call_args[0] == 'rows_test_user_test_collection_customers_384'
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Allow this test module to be executed directly.
    pytest.main([__file__])
|
||||
474
tests/unit/test_storage/test_rows_cassandra_storage.py
Normal file
474
tests/unit/test_storage/test_rows_cassandra_storage.py
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
"""
|
||||
Unit tests for Cassandra Row Storage Processor (Unified Table Implementation)
|
||||
|
||||
Tests the business logic of the row storage processor including:
|
||||
- Schema configuration handling
|
||||
- Name sanitization
|
||||
- Unified table structure
|
||||
- Index management
|
||||
- Row storage with multi-index support
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
from trustgraph.storage.rows.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
class TestRowsCassandraStorageLogic:
|
||||
"""Test business logic for unified table implementation"""
|
||||
|
||||
def test_sanitize_name(self):
|
||||
"""Test name sanitization for Cassandra compatibility"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
||||
# Test various name patterns
|
||||
assert processor.sanitize_name("simple_name") == "simple_name"
|
||||
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
|
||||
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
|
||||
assert processor.sanitize_name("123_starts_with_number") == "r_123_starts_with_number"
|
||||
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
|
||||
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
|
||||
assert processor.sanitize_name("UPPERCASE") == "uppercase"
|
||||
assert processor.sanitize_name("CamelCase") == "camelcase"
|
||||
assert processor.sanitize_name("_underscore_start") == "r__underscore_start"
|
||||
|
||||
def test_get_index_names(self):
|
||||
"""Test extraction of index names from schema"""
|
||||
processor = MagicMock()
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
|
||||
# Schema with primary and indexed fields
|
||||
schema = RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string"), # Not indexed
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
index_names = processor.get_index_names(schema)
|
||||
|
||||
# Should include primary key and indexed fields
|
||||
assert "id" in index_names
|
||||
assert "category" in index_names
|
||||
assert "status" in index_names
|
||||
assert "name" not in index_names # Not indexed
|
||||
assert len(index_names) == 3
|
||||
|
||||
def test_get_index_names_no_indexes(self):
|
||||
"""Test schema with no indexed fields"""
|
||||
processor = MagicMock()
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(
|
||||
name="no_index_schema",
|
||||
fields=[
|
||||
Field(name="data1", type="string"),
|
||||
Field(name="data2", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
index_names = processor.get_index_names(schema)
|
||||
assert len(index_names) == 0
|
||||
|
||||
def test_build_index_value(self):
|
||||
"""Test building index values from row data"""
|
||||
processor = MagicMock()
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
|
||||
value_map = {"id": "123", "category": "electronics", "name": "Widget"}
|
||||
|
||||
# Single field index
|
||||
result = processor.build_index_value(value_map, "id")
|
||||
assert result == ["123"]
|
||||
|
||||
result = processor.build_index_value(value_map, "category")
|
||||
assert result == ["electronics"]
|
||||
|
||||
# Missing field returns empty string
|
||||
result = processor.build_index_value(value_map, "missing")
|
||||
assert result == [""]
|
||||
|
||||
def test_build_index_value_composite(self):
|
||||
"""Test building composite index values"""
|
||||
processor = MagicMock()
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
|
||||
value_map = {"region": "us-west", "category": "electronics", "id": "123"}
|
||||
|
||||
# Composite index (comma-separated field names)
|
||||
result = processor.build_index_value(value_map, "region,category")
|
||||
assert result == ["us-west", "electronics"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configurations"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.registered_partitions = set()
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test configuration
|
||||
config = {
|
||||
"schema": {
|
||||
"customer_records": json.dumps({
|
||||
"name": "customer_records",
|
||||
"description": "Customer data",
|
||||
"fields": [
|
||||
{
|
||||
"name": "id",
|
||||
"type": "string",
|
||||
"primary_key": True,
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "category",
|
||||
"type": "string",
|
||||
"indexed": True
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
# Process configuration
|
||||
await processor.on_schema_config(config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer_records" in processor.schemas
|
||||
schema = processor.schemas["customer_records"]
|
||||
assert schema.name == "customer_records"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
# Check field properties
|
||||
id_field = schema.fields[0]
|
||||
assert id_field.name == "id"
|
||||
assert id_field.type == "string"
|
||||
assert id_field.primary is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_processing_stores_data_map(self):
|
||||
"""Test that row processing stores data as map<text, text>"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="string", size=100)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.ensure_tables = MagicMock()
|
||||
processor.register_partitions = MagicMock()
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
# Create test object
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="test_schema",
|
||||
values=[{"id": "123", "value": "test_data"}],
|
||||
confidence=0.9,
|
||||
source_span="test source"
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
# Process object
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify insert was executed
|
||||
processor.session.execute.assert_called()
|
||||
insert_call = processor.session.execute.call_args
|
||||
insert_cql = insert_call[0][0]
|
||||
values = insert_call[0][1]
|
||||
|
||||
# Verify using unified rows table
|
||||
assert "INSERT INTO test_user.rows" in insert_cql
|
||||
|
||||
# Values should be: (collection, schema_name, index_name, index_value, data, source)
|
||||
assert values[0] == "test_collection" # collection
|
||||
assert values[1] == "test_schema" # schema_name
|
||||
assert values[2] == "id" # index_name (primary key field)
|
||||
assert values[3] == ["123"] # index_value as list
|
||||
assert values[4] == {"id": "123", "value": "test_data"} # data map
|
||||
assert values[5] == "" # source
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_processing_multiple_indexes(self):
|
||||
"""Test that row is written once per indexed field"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"multi_index_schema": RowSchema(
|
||||
name="multi_index_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.ensure_tables = MagicMock()
|
||||
processor.register_partitions = MagicMock()
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="multi_index_schema",
|
||||
values=[{"id": "123", "category": "electronics", "status": "active"}],
|
||||
confidence=0.9,
|
||||
source_span=""
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Should have 3 inserts (one per indexed field: id, category, status)
|
||||
assert processor.session.execute.call_count == 3
|
||||
|
||||
# Check that different index_names were used
|
||||
index_names_used = set()
|
||||
for call in processor.session.execute.call_args_list:
|
||||
values = call[0][1]
|
||||
index_names_used.add(values[2]) # index_name is 3rd value
|
||||
|
||||
assert index_names_used == {"id", "category", "status"}
|
||||
|
||||
|
||||
class TestRowsCassandraStorageBatchLogic:
|
||||
"""Test batch processing logic for unified table implementation"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_object_processing(self):
|
||||
"""Test processing of batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"batch_schema": RowSchema(
|
||||
name="batch_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.ensure_tables = MagicMock()
|
||||
processor.register_partitions = MagicMock()
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
# Create batch object with multiple values
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="batch_schema",
|
||||
values=[
|
||||
{"id": "001", "name": "First"},
|
||||
{"id": "002", "name": "Second"},
|
||||
{"id": "003", "name": "Third"}
|
||||
],
|
||||
confidence=0.95,
|
||||
source_span=""
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Should have 3 inserts (one per row, one index per row since only primary key)
|
||||
assert processor.session.execute.call_count == 3
|
||||
|
||||
# Check each insert has different id
|
||||
ids_inserted = set()
|
||||
for call in processor.session.execute.call_args_list:
|
||||
values = call[0][1]
|
||||
ids_inserted.add(tuple(values[3])) # index_value is 4th value
|
||||
|
||||
assert ids_inserted == {("001",), ("002",), ("003",)}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_batch_processing(self):
|
||||
"""Test processing of empty batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"empty_schema": RowSchema(
|
||||
name="empty_schema",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.ensure_tables = MagicMock()
|
||||
processor.register_partitions = MagicMock()
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
# Create empty batch object
|
||||
empty_batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="empty-001",
|
||||
user="test_user",
|
||||
collection="empty_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="empty_schema",
|
||||
values=[], # Empty batch
|
||||
confidence=1.0,
|
||||
source_span=""
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = empty_batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify no insert calls for empty batch
|
||||
processor.session.execute.assert_not_called()
|
||||
|
||||
|
||||
class TestUnifiedTableStructure:
|
||||
"""Test the unified rows table structure"""
|
||||
|
||||
def test_ensure_tables_creates_unified_structure(self):
|
||||
"""Test that ensure_tables creates the unified rows table"""
|
||||
processor = MagicMock()
|
||||
processor.known_keyspaces = {"test_user"}
|
||||
processor.tables_initialized = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.ensure_keyspace = MagicMock()
|
||||
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
|
||||
|
||||
processor.ensure_tables("test_user")
|
||||
|
||||
# Should have 2 calls: create rows table + create row_partitions table
|
||||
assert processor.session.execute.call_count == 2
|
||||
|
||||
# Check rows table creation
|
||||
rows_cql = processor.session.execute.call_args_list[0][0][0]
|
||||
assert "CREATE TABLE IF NOT EXISTS test_user.rows" in rows_cql
|
||||
assert "collection text" in rows_cql
|
||||
assert "schema_name text" in rows_cql
|
||||
assert "index_name text" in rows_cql
|
||||
assert "index_value frozen<list<text>>" in rows_cql
|
||||
assert "data map<text, text>" in rows_cql
|
||||
assert "source text" in rows_cql
|
||||
assert "PRIMARY KEY ((collection, schema_name, index_name), index_value)" in rows_cql
|
||||
|
||||
# Check row_partitions table creation
|
||||
partitions_cql = processor.session.execute.call_args_list[1][0][0]
|
||||
assert "CREATE TABLE IF NOT EXISTS test_user.row_partitions" in partitions_cql
|
||||
assert "PRIMARY KEY ((collection), schema_name, index_name)" in partitions_cql
|
||||
|
||||
# Verify keyspace added to initialized set
|
||||
assert "test_user" in processor.tables_initialized
|
||||
|
||||
def test_ensure_tables_idempotent(self):
|
||||
"""Test that ensure_tables is idempotent"""
|
||||
processor = MagicMock()
|
||||
processor.tables_initialized = {"test_user"} # Already initialized
|
||||
processor.session = MagicMock()
|
||||
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
|
||||
|
||||
processor.ensure_tables("test_user")
|
||||
|
||||
# Should not execute any CQL since already initialized
|
||||
processor.session.execute.assert_not_called()
|
||||
|
||||
|
||||
class TestPartitionRegistration:
|
||||
"""Test partition registration for tracking what's stored"""
|
||||
|
||||
def test_register_partitions(self):
|
||||
"""Test registering partitions for a collection/schema pair"""
|
||||
processor = MagicMock()
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
|
||||
# Should have 2 inserts (one per index: id, category)
|
||||
assert processor.session.execute.call_count == 2
|
||||
|
||||
# Verify cache was updated
|
||||
assert ("test_collection", "test_schema") in processor.registered_partitions
|
||||
|
||||
def test_register_partitions_idempotent(self):
|
||||
"""Test that partition registration is idempotent"""
|
||||
processor = MagicMock()
|
||||
processor.registered_partitions = {("test_collection", "test_schema")} # Already registered
|
||||
processor.session = MagicMock()
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
|
||||
# Should not execute any CQL since already registered
|
||||
processor.session.execute.assert_not_called()
|
||||
|
|
@ -6,7 +6,8 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.storage.triples.cassandra.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
from trustgraph.schema import Triple, LITERAL, IRI
|
||||
from trustgraph.direct.cassandra_kg import DEFAULT_GRAPH
|
||||
|
||||
|
||||
class TestCassandraStorageProcessor:
|
||||
|
|
@ -86,29 +87,29 @@ class TestCassandraStorageProcessor:
|
|||
assert processor.cassandra_username == 'new-user' # Only cassandra_* params work
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_table_switching_with_auth(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_table_switching_with_auth(self, mock_kg_class):
|
||||
"""Test table switching logic when authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
cassandra_username='testuser',
|
||||
cassandra_password='testpass'
|
||||
)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify KnowledgeGraph was called with auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
mock_kg_class.assert_called_once_with(
|
||||
hosts=['cassandra'], # Updated default
|
||||
keyspace='user1',
|
||||
username='testuser',
|
||||
|
|
@ -117,128 +118,150 @@ class TestCassandraStorageProcessor:
|
|||
assert processor.table == 'user1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_table_switching_without_auth(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_table_switching_without_auth(self, mock_kg_class):
|
||||
"""Test table switching logic when no authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user2'
|
||||
mock_message.metadata.collection = 'collection2'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify KnowledgeGraph was called without auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
mock_kg_class.assert_called_once_with(
|
||||
hosts=['cassandra'], # Updated default
|
||||
keyspace='user2'
|
||||
)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_table_reuse_when_same(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_table_reuse_when_same(self, mock_kg_class):
|
||||
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
# First call should create TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
assert mock_kg_class.call_count == 1
|
||||
|
||||
# Second call with same table should reuse TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
assert mock_kg_class.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_triple_insertion(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_triple_insertion(self, mock_kg_class):
|
||||
"""Test that triples are properly inserted into Cassandra"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triples
|
||||
|
||||
# Create mock triples with proper Term structure
|
||||
triple1 = MagicMock()
|
||||
triple1.s.type = LITERAL
|
||||
triple1.s.value = 'subject1'
|
||||
triple1.s.datatype = ''
|
||||
triple1.s.language = ''
|
||||
triple1.p.type = LITERAL
|
||||
triple1.p.value = 'predicate1'
|
||||
triple1.o.type = LITERAL
|
||||
triple1.o.value = 'object1'
|
||||
|
||||
triple1.o.datatype = ''
|
||||
triple1.o.language = ''
|
||||
triple1.g = None
|
||||
|
||||
triple2 = MagicMock()
|
||||
triple2.s.type = LITERAL
|
||||
triple2.s.value = 'subject2'
|
||||
triple2.s.datatype = ''
|
||||
triple2.s.language = ''
|
||||
triple2.p.type = LITERAL
|
||||
triple2.p.value = 'predicate2'
|
||||
triple2.o.type = LITERAL
|
||||
triple2.o.value = 'object2'
|
||||
|
||||
triple2.o.datatype = ''
|
||||
triple2.o.language = ''
|
||||
triple2.g = None
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = [triple1, triple2]
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify both triples were inserted
|
||||
|
||||
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
mock_tg_instance.insert.assert_any_call('collection1', 'subject1', 'predicate1', 'object1')
|
||||
mock_tg_instance.insert.assert_any_call('collection1', 'subject2', 'predicate2', 'object2')
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
'collection1', 'subject1', 'predicate1', 'object1',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
'collection1', 'subject2', 'predicate2', 'object2',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_triple_insertion_with_empty_list(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_triple_insertion_with_empty_list(self, mock_kg_class):
|
||||
"""Test behavior when message has no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Create mock message with empty triples
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph):
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
|
||||
"""Test exception handling during TrustGraph creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
|
|
@ -326,92 +349,104 @@ class TestCassandraStorageProcessor:
|
|||
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_store_triples_table_switching_between_different_tables(self, mock_kg_class):
|
||||
"""Test table switching when different tables are used in sequence"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# First message with table1
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'user1'
|
||||
mock_message1.metadata.collection = 'collection1'
|
||||
mock_message1.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message1)
|
||||
assert processor.table == 'user1'
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'user2'
|
||||
mock_message2.metadata.collection = 'collection2'
|
||||
mock_message2.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message2)
|
||||
assert processor.table == 'user2'
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
|
||||
# Verify TrustGraph was created twice for different tables
|
||||
assert mock_trustgraph.call_count == 2
|
||||
assert mock_kg_class.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_store_triples_with_special_characters_in_values(self, mock_kg_class):
|
||||
"""Test storing triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create triple with special characters
|
||||
|
||||
# Create triple with special characters and proper Term structure
|
||||
triple = MagicMock()
|
||||
triple.s.type = LITERAL
|
||||
triple.s.value = 'subject with spaces & symbols'
|
||||
triple.s.datatype = ''
|
||||
triple.s.language = ''
|
||||
triple.p.type = LITERAL
|
||||
triple.p.value = 'predicate:with/colons'
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = 'object with "quotes" and unicode: ñáéíóú'
|
||||
|
||||
triple.o.datatype = ''
|
||||
triple.o.language = ''
|
||||
triple.g = None
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = [triple]
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
'test_collection',
|
||||
'subject with spaces & symbols',
|
||||
'predicate:with/colons',
|
||||
'object with "quotes" and unicode: ñáéíóú'
|
||||
'object with "quotes" and unicode: ñáéíóú',
|
||||
g=DEFAULT_GRAPH,
|
||||
otype='l',
|
||||
dtype='',
|
||||
lang=''
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
|
||||
"""Test that table remains unchanged when TrustGraph creation fails"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Set an initial table
|
||||
processor.table = ('old_user', 'old_collection')
|
||||
|
||||
|
||||
# Mock TrustGraph to raise exception
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
# TrustGraph should be set to None though
|
||||
|
|
@ -422,12 +457,12 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test cases for multi-table performance optimizations"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_legacy_mode_uses_single_table(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_legacy_mode_uses_single_table(self, mock_kg_class):
|
||||
"""Test that legacy mode still works with single table"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -440,16 +475,15 @@ class TestCassandraPerformanceOptimizations:
|
|||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify KnowledgeGraph instance uses legacy mode
|
||||
kg_instance = mock_trustgraph.return_value
|
||||
assert kg_instance is not None
|
||||
assert mock_tg_instance is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_optimized_mode_uses_multi_table(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_optimized_mode_uses_multi_table(self, mock_kg_class):
|
||||
"""Test that optimized mode uses multi-table schema"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -462,24 +496,31 @@ class TestCassandraPerformanceOptimizations:
|
|||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify KnowledgeGraph instance is in optimized mode
|
||||
kg_instance = mock_trustgraph.return_value
|
||||
assert kg_instance is not None
|
||||
assert mock_tg_instance is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_batch_write_consistency(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_batch_write_consistency(self, mock_kg_class):
|
||||
"""Test that all tables stay consistent during batch writes"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create test triple
|
||||
# Create test triple with proper Term structure
|
||||
triple = MagicMock()
|
||||
triple.s.type = LITERAL
|
||||
triple.s.value = 'test_subject'
|
||||
triple.s.datatype = ''
|
||||
triple.s.language = ''
|
||||
triple.p.type = LITERAL
|
||||
triple.p.value = 'test_predicate'
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = 'test_object'
|
||||
triple.o.datatype = ''
|
||||
triple.o.language = ''
|
||||
triple.g = None
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
|
|
@ -490,7 +531,8 @@ class TestCassandraPerformanceOptimizations:
|
|||
|
||||
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
'collection1', 'test_subject', 'test_predicate', 'test_object'
|
||||
'collection1', 'test_subject', 'test_predicate', 'test_object',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
||||
def test_environment_variable_controls_mode(self):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.triples.falkordb.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
from trustgraph.schema import Term, Triple, IRI, LITERAL
|
||||
|
||||
|
||||
class TestFalkorDBStorageProcessor:
|
||||
|
|
@ -22,9 +22,9 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
# Create a test triple
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=LITERAL, value='literal object')
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
|
|
@ -183,9 +183,9 @@ class TestFalkorDBStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='http://example.com/object', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=IRI, iri='http://example.com/object')
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
|
|
@ -269,14 +269,14 @@ class TestFalkorDBStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
s=Value(value='http://example.com/subject1', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate1', is_uri=True),
|
||||
o=Value(value='literal object1', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject1'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate1'),
|
||||
o=Term(type=LITERAL, value='literal object1')
|
||||
)
|
||||
triple2 = Triple(
|
||||
s=Value(value='http://example.com/subject2', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate2', is_uri=True),
|
||||
o=Value(value='http://example.com/object2', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject2'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate2'),
|
||||
o=Term(type=IRI, iri='http://example.com/object2')
|
||||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
|
|
@ -337,14 +337,14 @@ class TestFalkorDBStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
s=Value(value='http://example.com/subject1', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate1', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject1'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate1'),
|
||||
o=Term(type=LITERAL, value='literal object')
|
||||
)
|
||||
triple2 = Triple(
|
||||
s=Value(value='http://example.com/subject2', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate2', is_uri=True),
|
||||
o=Value(value='http://example.com/object2', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject2'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate2'),
|
||||
o=Term(type=IRI, iri='http://example.com/object2')
|
||||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.triples.memgraph.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
from trustgraph.schema import Term, Triple, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMemgraphStorageProcessor:
|
||||
|
|
@ -22,9 +22,9 @@ class TestMemgraphStorageProcessor:
|
|||
|
||||
# Create a test triple
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=LITERAL, value='literal object')
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
|
|
@ -231,9 +231,9 @@ class TestMemgraphStorageProcessor:
|
|||
mock_tx = MagicMock()
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='http://example.com/object', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=IRI, iri='http://example.com/object')
|
||||
)
|
||||
|
||||
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
|
||||
|
|
@ -265,9 +265,9 @@ class TestMemgraphStorageProcessor:
|
|||
mock_tx = MagicMock()
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=LITERAL, value='literal object')
|
||||
)
|
||||
|
||||
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
|
||||
|
|
@ -347,14 +347,14 @@ class TestMemgraphStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
s=Value(value='http://example.com/subject1', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate1', is_uri=True),
|
||||
o=Value(value='literal object1', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject1'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate1'),
|
||||
o=Term(type=LITERAL, value='literal object1')
|
||||
)
|
||||
triple2 = Triple(
|
||||
s=Value(value='http://example.com/subject2', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate2', is_uri=True),
|
||||
o=Value(value='http://example.com/object2', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject2'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate2'),
|
||||
o=Term(type=IRI, iri='http://example.com/object2')
|
||||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.storage.triples.neo4j.write import Processor
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
|
||||
|
||||
class TestNeo4jStorageProcessor:
|
||||
|
|
@ -257,10 +258,12 @@ class TestNeo4jStorageProcessor:
|
|||
|
||||
# Create mock triple with URI object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "http://example.com/object"
|
||||
triple.o.is_uri = True
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = IRI
|
||||
triple.o.iri = "http://example.com/object"
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -327,10 +330,12 @@ class TestNeo4jStorageProcessor:
|
|||
|
||||
# Create mock triple with literal object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = "literal value"
|
||||
triple.o.is_uri = False
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -398,16 +403,20 @@ class TestNeo4jStorageProcessor:
|
|||
|
||||
# Create mock triples
|
||||
triple1 = MagicMock()
|
||||
triple1.s.value = "http://example.com/subject1"
|
||||
triple1.p.value = "http://example.com/predicate1"
|
||||
triple1.o.value = "http://example.com/object1"
|
||||
triple1.o.is_uri = True
|
||||
|
||||
triple1.s.type = IRI
|
||||
triple1.s.iri = "http://example.com/subject1"
|
||||
triple1.p.type = IRI
|
||||
triple1.p.iri = "http://example.com/predicate1"
|
||||
triple1.o.type = IRI
|
||||
triple1.o.iri = "http://example.com/object1"
|
||||
|
||||
triple2 = MagicMock()
|
||||
triple2.s.value = "http://example.com/subject2"
|
||||
triple2.p.value = "http://example.com/predicate2"
|
||||
triple2.s.type = IRI
|
||||
triple2.s.iri = "http://example.com/subject2"
|
||||
triple2.p.type = IRI
|
||||
triple2.p.iri = "http://example.com/predicate2"
|
||||
triple2.o.type = LITERAL
|
||||
triple2.o.value = "literal value"
|
||||
triple2.o.is_uri = False
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -550,10 +559,12 @@ class TestNeo4jStorageProcessor:
|
|||
|
||||
# Create triple with special characters
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject with spaces"
|
||||
triple.p.value = "http://example.com/predicate:with/symbols"
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject with spaces"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate:with/symbols"
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = 'literal with "quotes" and unicode: ñáéíóú'
|
||||
triple.o.is_uri = False
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert hasattr(processor, 'client')
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert len(processor.safety_settings) == 4 # 4 safety categories
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
# Verify Google AI Studio client was called with correct API key
|
||||
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False)
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.client == mock_genai_client
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.vertexai
|
||||
Starting simple with one test to get the basics working
|
||||
Updated for google-genai SDK
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -15,19 +15,20 @@ from trustgraph.base import LlmResult
|
|||
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Simple test for processor initialization"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test basic processor initialization with mocked dependencies"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
# Mock the parent class initialization to avoid taskgroup requirement
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
|
@ -47,32 +48,38 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||
assert hasattr(processor, 'generation_configs') # Cache dictionary
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert hasattr(processor, 'model_clients') # LLM clients are now cached
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
|
||||
mock_vertexai.init.assert_called_once()
|
||||
assert hasattr(processor, 'client') # genai.Client
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once()
|
||||
mock_genai.Client.assert_called_once_with(
|
||||
vertexai=True,
|
||||
project="test-project-123",
|
||||
location="us-central1",
|
||||
credentials=mock_credentials
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response from Gemini"
|
||||
mock_response.usage_metadata.prompt_token_count = 15
|
||||
mock_response.usage_metadata.candidates_token_count = 8
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -98,32 +105,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
# Check that the method was called (actual prompt format may vary)
|
||||
mock_model.generate_content.assert_called_once()
|
||||
# Verify the call was made with the expected parameters
|
||||
call_args = mock_model.generate_content.call_args
|
||||
# Generation config is now created dynamically per model
|
||||
assert 'generation_config' in call_args[1]
|
||||
assert call_args[1]['safety_settings'] == processor.safety_settings
|
||||
mock_client.models.generate_content.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -144,25 +145,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test handling of blocked content (safety filters)"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = None # Blocked content returns None
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 0
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -190,24 +192,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
|
||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_auth_default):
|
||||
"""Test processor initialization without private key (uses default credentials)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
||||
# Mock google.auth.default() to return credentials and project ID
|
||||
mock_credentials = MagicMock()
|
||||
mock_auth_default.return_value = (mock_credentials, "test-project-123")
|
||||
|
||||
# Mock GenerativeModel
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
|
|
@ -222,30 +222,32 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||
mock_auth_default.assert_called_once()
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
location='us-central1',
|
||||
project='test-project-123'
|
||||
mock_genai.Client.assert_called_once_with(
|
||||
vertexai=True,
|
||||
project="test-project-123",
|
||||
location="us-central1",
|
||||
credentials=mock_credentials
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = Exception("Network error")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.side_effect = Exception("Network error")
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -266,19 +268,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
with pytest.raises(Exception, match="Network error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -298,37 +301,37 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
|
||||
# Verify that generation_config object exists (can't easily check internal values)
|
||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
||||
|
||||
# Verify that generation_config cache exists
|
||||
assert hasattr(processor, 'generation_configs')
|
||||
assert processor.generation_configs == {} # Empty cache initially
|
||||
|
||||
|
||||
# Verify that safety settings are configured
|
||||
assert len(processor.safety_settings) == 4
|
||||
|
||||
|
||||
# Verify service account was called with custom key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
|
||||
|
||||
# Verify that api_params dict has the correct values (this is accessible)
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once()
|
||||
|
||||
# Verify that api_params dict has the correct values
|
||||
assert processor.api_params["temperature"] == 0.7
|
||||
assert processor.api_params["max_output_tokens"] == 4096
|
||||
assert processor.api_params["top_p"] == 1.0
|
||||
assert processor.api_params["top_k"] == 32
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test that VertexAI is initialized correctly with credentials"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -347,35 +350,34 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify VertexAI init was called with correct parameters
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
# Verify genai.Client was called with correct parameters
|
||||
mock_genai.Client.assert_called_once_with(
|
||||
vertexai=True,
|
||||
project='test-project-123',
|
||||
location='europe-west1',
|
||||
credentials=mock_credentials,
|
||||
project='test-project-123'
|
||||
credentials=mock_credentials
|
||||
)
|
||||
|
||||
# GenerativeModel is now created lazily on first use, not at initialization
|
||||
mock_generative_model.assert_not_called()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Default response"
|
||||
mock_response.usage_metadata.prompt_token_count = 2
|
||||
mock_response.usage_metadata.candidates_token_count = 3
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -401,27 +403,28 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the model was called with the combined empty prompts
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
assert call_args[0][0] == "\n\n"
|
||||
|
||||
# Verify the client was called
|
||||
mock_client.models.generate_content.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
|
||||
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_anthropic_vertex):
|
||||
"""Test Anthropic processor initialization with private key credentials"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-456"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
# Mock AnthropicVertex
|
||||
mock_anthropic_client = MagicMock()
|
||||
mock_anthropic_vertex.return_value = mock_anthropic_client
|
||||
|
|
@ -439,45 +442,45 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'claude-3-sonnet@20240229'
|
||||
# is_anthropic logic is now determined dynamically per request
|
||||
|
||||
|
||||
# Verify service account was called with private key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
|
||||
|
||||
# Verify AnthropicVertex was initialized with credentials
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once()
|
||||
|
||||
# Verify AnthropicVertex was initialized with credentials (because model contains 'claude')
|
||||
mock_anthropic_vertex.assert_called_once_with(
|
||||
region='us-west1',
|
||||
project_id='test-project-456',
|
||||
credentials=mock_credentials
|
||||
)
|
||||
|
||||
|
||||
# Verify api_params are set correctly
|
||||
assert processor.api_params["temperature"] == 0.5
|
||||
assert processor.api_params["max_output_tokens"] == 2048
|
||||
assert processor.api_params["top_p"] == 1.0
|
||||
assert processor.api_params["top_k"] == 32
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom temperature"
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
|
@ -506,42 +509,27 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
mock_client.models.generate_content.assert_called_once()
|
||||
|
||||
# Verify Gemini API was called with overridden temperature
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Check that generation_config was created (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
# Mock different models
|
||||
mock_model_default = MagicMock()
|
||||
mock_model_override = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom model"
|
||||
mock_response.usage_metadata.prompt_token_count = 18
|
||||
mock_response.usage_metadata.candidates_token_count = 14
|
||||
mock_model_override.generate_content.return_value = mock_response
|
||||
|
||||
# GenerativeModel should return different models based on input
|
||||
def model_factory(model_name):
|
||||
if model_name == 'gemini-1.5-pro':
|
||||
return mock_model_override
|
||||
return mock_model_default
|
||||
|
||||
mock_generative_model.side_effect = model_factory
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
|
@ -549,7 +537,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'temperature': 0.2, # Default temperature
|
||||
'temperature': 0.2,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
|
|
@ -571,29 +559,29 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify the overridden model was used
|
||||
mock_model_override.generate_content.assert_called_once()
|
||||
# Verify GenerativeModel was called with the override model
|
||||
mock_generative_model.assert_called_with('gemini-1.5-pro')
|
||||
# Verify the call was made with the override model
|
||||
call_args = mock_client.models.generate_content.call_args
|
||||
assert call_args.kwargs['model'] == "gemini-1.5-pro"
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with both overrides"
|
||||
mock_response.usage_metadata.prompt_token_count = 22
|
||||
mock_response.usage_metadata.candidates_token_count = 16
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
|
@ -622,18 +610,12 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
mock_client.models.generate_content.assert_called_once()
|
||||
|
||||
# Verify both overrides were used
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Verify model override
|
||||
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
|
||||
|
||||
# Verify temperature override (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
# Verify the model override was used
|
||||
call_args = mock_client.models.generate_content.call_args
|
||||
assert call_args.kwargs['model'] == "gemini-1.5-flash-001"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue