Merge 2.0 to master (#651)

This commit is contained in:
cybermaggedon 2026-02-28 11:03:14 +00:00 committed by GitHub
parent 3666ece2c5
commit b9d7bf9a8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
212 changed files with 13940 additions and 6180 deletions

View file

@ -88,8 +88,13 @@ async def test_subscriber_deferred_acknowledgment_success():
@pytest.mark.asyncio
async def test_subscriber_deferred_acknowledgment_failure():
"""Verify Subscriber negative acks on delivery failure."""
async def test_subscriber_dropped_message_still_acks():
"""Verify Subscriber acks even when message is dropped (backpressure).
This prevents redelivery storms on shared topics - if we negative_ack
a dropped message, it gets redelivered to all subscribers, none of
whom can handle it either, causing a tight redelivery loop.
"""
mock_backend = MagicMock()
mock_consumer = MagicMock()
mock_backend.create_consumer.return_value = mock_consumer
@ -103,24 +108,66 @@ async def test_subscriber_deferred_acknowledgment_failure():
max_size=1, # Very small queue
backpressure_strategy="drop_new"
)
# Start subscriber to initialize consumer
await subscriber.start()
# Create queue and fill it
queue = await subscriber.subscribe("test-queue")
await queue.put({"existing": "data"})
# Create mock message - should be dropped
msg = create_mock_message("msg-1", {"data": "test"})
# Process message (should fail due to full queue + drop_new strategy)
# Create mock message - should be dropped due to full queue
msg = create_mock_message("test-queue", {"data": "test"})
# Process message (should be dropped due to full queue + drop_new strategy)
await subscriber._process_message(msg)
# Should negative acknowledge failed delivery
mock_consumer.negative_acknowledge.assert_called_once_with(msg)
mock_consumer.acknowledge.assert_not_called()
# Should acknowledge even though delivery failed - prevents redelivery storm
mock_consumer.acknowledge.assert_called_once_with(msg)
mock_consumer.negative_acknowledge.assert_not_called()
# Clean up
await subscriber.stop()
@pytest.mark.asyncio
async def test_subscriber_orphaned_message_acks():
    """Verify Subscriber acks orphaned messages (no matching waiter).

    On shared response topics, if a message arrives for a waiter that
    no longer exists (e.g., client disconnected, request timed out),
    we must acknowledge it to prevent redelivery storms.
    """
    fake_consumer = MagicMock()
    fake_backend = MagicMock()
    fake_backend.create_consumer.return_value = fake_consumer

    sub = Subscriber(
        backend=fake_backend,
        topic="test-topic",
        subscription="test-subscription",
        consumer_name="test-consumer",
        schema=dict,
        max_size=10,
        backpressure_strategy="block"
    )

    # Starting the subscriber wires up the (mocked) consumer.
    await sub.start()

    # Deliberately subscribe no queues: the incoming message will have no
    # waiter, mirroring a response that arrives after its requester left.
    orphan = create_mock_message("non-existent-waiter-id", {"data": "orphaned"})
    await sub._process_message(orphan)

    # An orphan must be positively acked — a negative ack would make the
    # broker redeliver it to subscribers that still cannot handle it.
    fake_consumer.acknowledge.assert_called_once_with(orphan)
    fake_consumer.negative_acknowledge.assert_not_called()

    await sub.stop()

View file

@ -55,6 +55,9 @@ class TestSetToolStructuredQuery:
mcp_tool=None,
collection="sales_data",
template=None,
schema_name=None,
index_name=None,
limit=None,
arguments=[],
group=None,
state=None,
@ -92,6 +95,9 @@ class TestSetToolStructuredQuery:
mcp_tool=None,
collection=None, # No collection specified
template=None,
schema_name=None,
index_name=None,
limit=None,
arguments=[],
group=None,
state=None,
@ -132,6 +138,9 @@ class TestSetToolStructuredQuery:
mcp_tool=None,
collection='sales_data',
template=None,
schema_name=None,
index_name=None,
limit=None,
arguments=[],
group=None,
state=None,
@ -201,6 +210,144 @@ class TestSetToolStructuredQuery:
assert 'Exception:' in printed_output or 'invalid choice:' in printed_output.lower()
class TestSetToolRowEmbeddingsQuery:
    """Test the set_tool function with row-embeddings-query type."""

    @patch('trustgraph.cli.set_tool.Api')
    def test_set_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys):
        """Test setting a row-embeddings-query tool with all parameters."""
        # mock_api fixture supplies an (api, config) mock pair; wire the
        # patched Api class so set_tool talks to the fixture's config mock.
        mock_api_class.return_value, mock_config = mock_api
        mock_config.get.return_value = []
        set_tool(
            url="http://test.com",
            id="customer_search",
            name="find_customer",
            description="Find customers by name using semantic search",
            type="row-embeddings-query",
            mcp_tool=None,
            collection="sales",
            template=None,
            schema_name="customers",
            index_name="full_name",
            limit=20,
            arguments=[],
            group=None,
            state=None,
            applicable_states=None
        )
        captured = capsys.readouterr()
        assert "Tool set." in captured.out
        # Verify the tool was stored correctly
        call_args = mock_config.put.call_args[0][0]
        assert len(call_args) == 1
        config_value = call_args[0]
        assert config_value.type == "tool"
        assert config_value.key == "customer_search"
        stored_tool = json.loads(config_value.value)
        assert stored_tool["name"] == "find_customer"
        assert stored_tool["type"] == "row-embeddings-query"
        assert stored_tool["collection"] == "sales"
        # Snake_case CLI options are stored as kebab-case JSON keys
        assert stored_tool["schema-name"] == "customers"
        assert stored_tool["index-name"] == "full_name"
        assert stored_tool["limit"] == 20

    @patch('trustgraph.cli.set_tool.Api')
    def test_set_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys):
        """Test setting row-embeddings-query tool with minimal parameters."""
        mock_api_class.return_value, mock_config = mock_api
        mock_config.get.return_value = []
        set_tool(
            url="http://test.com",
            id="product_search",
            name="find_product",
            description="Find products using semantic search",
            type="row-embeddings-query",
            mcp_tool=None,
            collection=None,
            template=None,
            schema_name="products",
            index_name=None,  # No index filter
            limit=None,  # Use default
            arguments=[],
            group=None,
            state=None,
            applicable_states=None
        )
        captured = capsys.readouterr()
        assert "Tool set." in captured.out
        call_args = mock_config.put.call_args[0][0]
        stored_tool = json.loads(call_args[0].value)
        assert stored_tool["type"] == "row-embeddings-query"
        assert stored_tool["schema-name"] == "products"
        # Optional fields left as None must be omitted from the stored JSON
        assert "index-name" not in stored_tool  # Should not be included if None
        assert "limit" not in stored_tool  # Should not be included if None
        assert "collection" not in stored_tool  # Should not be included if None

    def test_set_main_row_embeddings_query_with_all_options(self):
        """Test set main() with row-embeddings-query tool type and all options."""
        test_args = [
            'tg-set-tool',
            '--id', 'customer_search',
            '--name', 'find_customer',
            '--type', 'row-embeddings-query',
            '--description', 'Find customers by name',
            '--schema-name', 'customers',
            '--collection', 'sales',
            '--index-name', 'full_name',
            '--limit', '25',
            '--api-url', 'http://custom.com'
        ]
        with patch('sys.argv', test_args), \
                patch('trustgraph.cli.set_tool.set_tool') as mock_set:
            set_main()
            # Argument parsing maps each CLI flag onto the matching
            # set_tool keyword; note '--limit' is converted to int.
            mock_set.assert_called_once_with(
                url='http://custom.com',
                id='customer_search',
                name='find_customer',
                description='Find customers by name',
                type='row-embeddings-query',
                mcp_tool=None,
                collection='sales',
                template=None,
                schema_name='customers',
                index_name='full_name',
                limit=25,
                arguments=[],
                group=None,
                state=None,
                applicable_states=None,
                token=None
            )

    def test_valid_types_includes_row_embeddings_query(self):
        """Test that 'row-embeddings-query' is included in valid tool types."""
        test_args = [
            'tg-set-tool',
            '--id', 'test_tool',
            '--name', 'test_tool',
            '--type', 'row-embeddings-query',
            '--description', 'Test tool',
            '--schema-name', 'test_schema'
        ]
        with patch('sys.argv', test_args), \
                patch('trustgraph.cli.set_tool.set_tool') as mock_set:
            # Should not raise an exception about invalid type
            set_main()
            mock_set.assert_called_once()
class TestShowToolsStructuredQuery:
"""Test the show_tools function with structured-query tools."""
@ -259,9 +406,9 @@ class TestShowToolsStructuredQuery:
@patch('trustgraph.cli.show_tools.Api')
def test_show_mixed_tool_types(self, mock_api_class, mock_api, capsys):
"""Test displaying multiple tool types including structured-query."""
"""Test displaying multiple tool types including structured-query and row-embeddings-query."""
mock_api_class.return_value, mock_config = mock_api
tools = [
{
"name": "ask_knowledge",
@ -270,37 +417,47 @@ class TestShowToolsStructuredQuery:
"collection": "docs"
},
{
"name": "query_data",
"name": "query_data",
"description": "Query structured data",
"type": "structured-query",
"collection": "sales"
},
{
"name": "find_customer",
"description": "Find customers by semantic search",
"type": "row-embeddings-query",
"schema-name": "customers",
"collection": "crm"
},
{
"name": "complete_text",
"description": "Generate text",
"type": "text-completion"
}
]
config_values = [
ConfigValue(type="tool", key=f"tool_{i}", value=json.dumps(tool))
for i, tool in enumerate(tools)
]
mock_config.get_values.return_value = config_values
show_config("http://test.com")
captured = capsys.readouterr()
output = captured.out
# All tool types should be displayed
assert "knowledge-query" in output
assert "structured-query" in output
assert "structured-query" in output
assert "row-embeddings-query" in output
assert "text-completion" in output
# Collections should be shown for appropriate tools
assert "docs" in output # knowledge-query collection
assert "sales" in output # structured-query collection
assert "crm" in output # row-embeddings-query collection
assert "customers" in output # row-embeddings-query schema-name
def test_show_main_parses_args_correctly(self):
"""Test that show main() parses arguments correctly."""
@ -317,6 +474,76 @@ class TestShowToolsStructuredQuery:
mock_show.assert_called_once_with(url='http://custom.com', token=None)
class TestShowToolsRowEmbeddingsQuery:
    """Test the show_tools function with row-embeddings-query tools."""

    @patch('trustgraph.cli.show_tools.Api')
    def test_show_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys):
        """Test displaying a row-embeddings-query tool with all fields."""
        mock_api_class.return_value, mock_config = mock_api
        tool_def = {
            "name": "find_customer",
            "description": "Find customers by name using semantic search",
            "type": "row-embeddings-query",
            "collection": "sales",
            "schema-name": "customers",
            "index-name": "full_name",
            "limit": 20
        }
        mock_config.get_values.return_value = [
            ConfigValue(
                type="tool",
                key="customer_search",
                value=json.dumps(tool_def)
            )
        ]
        show_config("http://test.com")
        text = capsys.readouterr().out
        # Every stored field should surface somewhere in the listing.
        for expected in (
            "customer_search",       # tool key
            "find_customer",         # tool name
            "row-embeddings-query",  # tool type
            "sales",                 # collection
            "customers",             # schema name
            "full_name",             # index name
            "20",                    # limit
        ):
            assert expected in text

    @patch('trustgraph.cli.show_tools.Api')
    def test_show_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys):
        """Test displaying row-embeddings-query tool with minimal fields."""
        mock_api_class.return_value, mock_config = mock_api
        # Only the mandatory fields: no collection, index-name, or limit.
        tool_def = {
            "name": "find_product",
            "description": "Find products using semantic search",
            "type": "row-embeddings-query",
            "schema-name": "products"
        }
        mock_config.get_values.return_value = [
            ConfigValue(
                type="tool",
                key="product_search",
                value=json.dumps(tool_def)
            )
        ]
        show_config("http://test.com")
        text = capsys.readouterr().out
        # The tool should still display with just its schema-name.
        assert "product_search" in text
        assert "row-embeddings-query" in text
        assert "products" in text  # Schema name
class TestStructuredQueryToolValidation:
"""Test validation specific to structured-query tools."""

View file

@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, Mock, patch, MagicMock
from unittest.mock import call
from trustgraph.cores.knowledge import KnowledgeManager
from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Value, EntityEmbeddings
from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, EntityEmbeddings, IRI, LITERAL
@pytest.fixture
@ -71,15 +71,15 @@ def sample_triples():
return Triples(
metadata=Metadata(
id="test-doc-id",
user="test-user",
user="test-user",
collection="default", # This should be overridden
metadata=[]
),
triples=[
Triple(
s=Value(value="http://example.org/john", is_uri=True),
p=Value(value="http://example.org/name", is_uri=True),
o=Value(value="John Smith", is_uri=False)
s=Term(type=IRI, iri="http://example.org/john"),
p=Term(type=IRI, iri="http://example.org/name"),
o=Term(type=LITERAL, value="John Smith")
)
]
)
@ -97,7 +97,7 @@ def sample_graph_embeddings():
),
entities=[
EntityEmbeddings(
entity=Value(value="http://example.org/john", is_uri=True),
entity=Term(type=IRI, iri="http://example.org/john"),
vectors=[[0.1, 0.2, 0.3]]
)
]

View file

@ -0,0 +1,599 @@
"""
Unit tests for EntityCentricKnowledgeGraph class
Tests the entity-centric knowledge graph implementation without requiring
an actual Cassandra connection. Uses mocking to verify correct behavior.
"""
import pytest
from unittest.mock import MagicMock, patch, call
import os
class TestEntityCentricKnowledgeGraph:
    """Test cases for EntityCentricKnowledgeGraph.

    All tests run against a mocked Cassandra cluster/session, so they
    verify statement construction and call patterns, not real storage.
    """

    @pytest.fixture
    def mock_cluster(self):
        """Create a mock Cassandra cluster.

        Yields (patched Cluster class, cluster mock, session mock) so
        tests can inspect both construction and query execution.
        """
        with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cluster_cls:
            mock_cluster = MagicMock()
            mock_session = MagicMock()
            mock_cluster.connect.return_value = mock_session
            mock_cluster_cls.return_value = mock_cluster
            yield mock_cluster_cls, mock_cluster, mock_session

    @pytest.fixture
    def entity_kg(self, mock_cluster):
        """Create an EntityCentricKnowledgeGraph instance with mocked Cassandra."""
        from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
        mock_cluster_cls, mock_cluster, mock_session = mock_cluster
        # Create instance; constructor runs schema setup against the mock session
        kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace')
        return kg, mock_session

    def test_init_creates_entity_centric_schema(self, mock_cluster):
        """Test that initialization creates the 2-table entity-centric schema."""
        from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
        mock_cluster_cls, mock_cluster, mock_session = mock_cluster
        kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace')
        # Inspect every CQL statement executed during construction
        execute_calls = mock_session.execute.call_args_list
        # Check for keyspace creation
        keyspace_created = any('create keyspace' in str(c).lower() for c in execute_calls)
        assert keyspace_created
        # Check for quads_by_entity table
        entity_table_created = any('quads_by_entity' in str(c) for c in execute_calls)
        assert entity_table_created
        # Check for quads_by_collection table
        collection_table_created = any('quads_by_collection' in str(c) for c in execute_calls)
        assert collection_table_created
        # Check for collection_metadata table
        metadata_table_created = any('collection_metadata' in str(c) for c in execute_calls)
        assert metadata_table_created

    def test_prepare_statements_initialized(self, entity_kg):
        """Test that prepared statements are initialized."""
        kg, mock_session = entity_kg
        # Verify prepare was called for various statements
        assert mock_session.prepare.called
        prepare_calls = mock_session.prepare.call_args_list
        # Insert statements must exist for both quad tables
        insert_entity_stmt = any('INSERT INTO' in str(c) and 'quads_by_entity' in str(c)
                                 for c in prepare_calls)
        assert insert_entity_stmt
        insert_collection_stmt = any('INSERT INTO' in str(c) and 'quads_by_collection' in str(c)
                                     for c in prepare_calls)
        assert insert_collection_stmt

    def test_insert_uri_object_creates_4_entity_rows(self, entity_kg):
        """Test that inserting a quad with URI object creates 4 entity rows."""
        kg, mock_session = entity_kg
        # Reset mocks to track only insert-related calls
        mock_session.reset_mock()
        kg.insert(
            collection='test_collection',
            s='http://example.org/Alice',
            p='http://example.org/knows',
            o='http://example.org/Bob',
            g='http://example.org/graph1',
            otype='u'
        )
        # Verify batch was executed
        mock_session.execute.assert_called()

    def test_insert_literal_object_creates_3_entity_rows(self, entity_kg):
        """Test that inserting a quad with literal object creates 3 entity rows."""
        kg, mock_session = entity_kg
        mock_session.reset_mock()
        kg.insert(
            collection='test_collection',
            s='http://example.org/Alice',
            p='http://www.w3.org/2000/01/rdf-schema#label',
            o='Alice Smith',
            g=None,
            otype='l',
            dtype='xsd:string',
            lang='en'
        )
        # Verify batch was executed
        mock_session.execute.assert_called()

    def test_insert_default_graph(self, entity_kg):
        """Test that None graph is stored as empty string."""
        kg, mock_session = entity_kg
        mock_session.reset_mock()
        kg.insert(
            collection='test_collection',
            s='http://example.org/Alice',
            p='http://example.org/knows',
            o='http://example.org/Bob',
            g=None,
            otype='u'
        )
        mock_session.execute.assert_called()

    def test_insert_auto_detects_otype(self, entity_kg):
        """Test that otype is auto-detected when not provided."""
        kg, mock_session = entity_kg
        mock_session.reset_mock()
        # URI should be auto-detected
        kg.insert(
            collection='test_collection',
            s='http://example.org/Alice',
            p='http://example.org/knows',
            o='http://example.org/Bob'
        )
        mock_session.execute.assert_called()
        mock_session.reset_mock()
        # Literal should be auto-detected
        kg.insert(
            collection='test_collection',
            s='http://example.org/Alice',
            p='http://example.org/name',
            o='Alice'
        )
        mock_session.execute.assert_called()

    def test_get_s_returns_quads_for_subject(self, entity_kg):
        """Test get_s queries by subject."""
        kg, mock_session = entity_kg
        # Mock the query result
        mock_result = [
            MagicMock(p='http://example.org/knows', o='http://example.org/Bob',
                      d='', otype='u', dtype='', lang='', s='http://example.org/Alice')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_s('test_collection', 'http://example.org/Alice')
        # Verify query was executed
        mock_session.execute.assert_called()
        # Results should be QuadResult objects
        assert len(results) == 1
        assert results[0].s == 'http://example.org/Alice'
        assert results[0].p == 'http://example.org/knows'
        assert results[0].o == 'http://example.org/Bob'

    def test_get_p_returns_quads_for_predicate(self, entity_kg):
        """Test get_p queries by predicate."""
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(s='http://example.org/Alice', o='http://example.org/Bob',
                      d='', otype='u', dtype='', lang='', p='http://example.org/knows')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_p('test_collection', 'http://example.org/knows')
        mock_session.execute.assert_called()
        assert len(results) == 1

    def test_get_o_returns_quads_for_object(self, entity_kg):
        """Test get_o queries by object."""
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(s='http://example.org/Alice', p='http://example.org/knows',
                      d='', otype='u', dtype='', lang='', o='http://example.org/Bob')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_o('test_collection', 'http://example.org/Bob')
        mock_session.execute.assert_called()
        assert len(results) == 1

    def test_get_sp_returns_quads_for_subject_predicate(self, entity_kg):
        """Test get_sp queries by subject and predicate."""
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(o='http://example.org/Bob', d='', otype='u', dtype='', lang='')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_sp('test_collection', 'http://example.org/Alice',
                            'http://example.org/knows')
        mock_session.execute.assert_called()
        assert len(results) == 1

    def test_get_po_returns_quads_for_predicate_object(self, entity_kg):
        """Test get_po queries by predicate and object."""
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(s='http://example.org/Alice', d='', otype='u', dtype='', lang='',
                      o='http://example.org/Bob')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_po('test_collection', 'http://example.org/knows',
                            'http://example.org/Bob')
        mock_session.execute.assert_called()
        assert len(results) == 1

    def test_get_os_returns_quads_for_object_subject(self, entity_kg):
        """Test get_os queries by object and subject."""
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(p='http://example.org/knows', d='', otype='u', dtype='', lang='',
                      s='http://example.org/Alice', o='http://example.org/Bob')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_os('test_collection', 'http://example.org/Bob',
                            'http://example.org/Alice')
        mock_session.execute.assert_called()
        assert len(results) == 1

    def test_get_spo_returns_quads_for_subject_predicate_object(self, entity_kg):
        """Test get_spo queries by subject, predicate, and object."""
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(d='', otype='u', dtype='', lang='',
                      o='http://example.org/Bob')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_spo('test_collection', 'http://example.org/Alice',
                             'http://example.org/knows', 'http://example.org/Bob')
        mock_session.execute.assert_called()
        assert len(results) == 1

    def test_get_g_returns_quads_for_graph(self, entity_kg):
        """Test get_g queries by graph."""
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(s='http://example.org/Alice', p='http://example.org/knows',
                      o='http://example.org/Bob', otype='u', dtype='', lang='')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_g('test_collection', 'http://example.org/graph1')
        mock_session.execute.assert_called()

    def test_get_all_returns_all_quads_in_collection(self, entity_kg):
        """Test get_all returns all quads."""
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows',
                      o='http://example.org/Bob', otype='u', dtype='', lang='')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_all('test_collection')
        mock_session.execute.assert_called()

    def test_graph_wildcard_returns_all_graphs(self, entity_kg):
        """Test that g='*' returns quads from all graphs."""
        from trustgraph.direct.cassandra_kg import GRAPH_WILDCARD
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(p='http://example.org/knows', d='http://example.org/graph1',
                      otype='u', dtype='', lang='', s='http://example.org/Alice',
                      o='http://example.org/Bob'),
            MagicMock(p='http://example.org/knows', d='http://example.org/graph2',
                      otype='u', dtype='', lang='', s='http://example.org/Alice',
                      o='http://example.org/Charlie')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_s('test_collection', 'http://example.org/Alice', g=GRAPH_WILDCARD)
        # Should return quads from both graphs
        assert len(results) == 2

    def test_specific_graph_filters_results(self, entity_kg):
        """Test that specifying a graph filters results."""
        kg, mock_session = entity_kg
        mock_result = [
            MagicMock(p='http://example.org/knows', d='http://example.org/graph1',
                      otype='u', dtype='', lang='', s='http://example.org/Alice',
                      o='http://example.org/Bob'),
            MagicMock(p='http://example.org/knows', d='http://example.org/graph2',
                      otype='u', dtype='', lang='', s='http://example.org/Alice',
                      o='http://example.org/Charlie')
        ]
        mock_session.execute.return_value = mock_result
        results = kg.get_s('test_collection', 'http://example.org/Alice',
                           g='http://example.org/graph1')
        # Should only return quads from graph1
        assert len(results) == 1
        assert results[0].g == 'http://example.org/graph1'

    def test_collection_exists_returns_true_when_exists(self, entity_kg):
        """Test collection_exists returns True for existing collection."""
        kg, mock_session = entity_kg
        mock_result = [MagicMock(collection='test_collection')]
        mock_session.execute.return_value = mock_result
        exists = kg.collection_exists('test_collection')
        assert exists is True

    def test_collection_exists_returns_false_when_not_exists(self, entity_kg):
        """Test collection_exists returns False for non-existing collection."""
        kg, mock_session = entity_kg
        mock_session.execute.return_value = []
        exists = kg.collection_exists('nonexistent_collection')
        assert exists is False

    def test_create_collection_inserts_metadata(self, entity_kg):
        """Test create_collection inserts metadata row."""
        kg, mock_session = entity_kg
        mock_session.reset_mock()
        kg.create_collection('test_collection')
        # Verify INSERT was executed for collection_metadata
        mock_session.execute.assert_called()

    def test_delete_collection_removes_all_data(self, entity_kg):
        """Test delete_collection removes entity partitions and collection rows."""
        kg, mock_session = entity_kg
        # Mock reading quads from collection
        mock_quads = [
            MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows',
                      o='http://example.org/Bob', otype='u')
        ]
        mock_session.execute.return_value = mock_quads
        mock_session.reset_mock()
        kg.delete_collection('test_collection')
        # Verify delete operations were executed
        assert mock_session.execute.called

    def test_close_shuts_down_connections(self, entity_kg):
        """Test close shuts down session and cluster."""
        kg, mock_session = entity_kg
        kg.close()
        mock_session.shutdown.assert_called_once()
        kg.cluster.shutdown.assert_called_once()
class TestQuadResult:
    """Test cases for QuadResult class."""

    def test_quad_result_stores_all_fields(self):
        """Test QuadResult stores all quad fields."""
        from trustgraph.direct.cassandra_kg import QuadResult
        quad = QuadResult(
            s='http://example.org/Alice',
            p='http://example.org/knows',
            o='http://example.org/Bob',
            g='http://example.org/graph1',
            otype='u',
            dtype='',
            lang=''
        )
        # Each constructor argument must be readable back unchanged.
        expected = {
            's': 'http://example.org/Alice',
            'p': 'http://example.org/knows',
            'o': 'http://example.org/Bob',
            'g': 'http://example.org/graph1',
            'otype': 'u',
            'dtype': '',
            'lang': '',
        }
        for attr, value in expected.items():
            assert getattr(quad, attr) == value

    def test_quad_result_defaults(self):
        """Test QuadResult default values."""
        from trustgraph.direct.cassandra_kg import QuadResult
        quad = QuadResult(
            s='http://example.org/s',
            p='http://example.org/p',
            o='literal value',
            g=''
        )
        # Unspecified metadata falls back to URI-typed, no datatype, no language.
        assert quad.otype == 'u'  # Default otype
        assert quad.dtype == ''
        assert quad.lang == ''

    def test_quad_result_with_literal_metadata(self):
        """Test QuadResult with literal metadata."""
        from trustgraph.direct.cassandra_kg import QuadResult
        quad = QuadResult(
            s='http://example.org/Alice',
            p='http://www.w3.org/2000/01/rdf-schema#label',
            o='Alice Smith',
            g='',
            otype='l',
            dtype='xsd:string',
            lang='en'
        )
        # Literal-specific metadata is carried through intact.
        assert (quad.otype, quad.dtype, quad.lang) == ('l', 'xsd:string', 'en')
class TestWriteHelperFunctions:
    """Test cases for helper functions in write.py."""

    def test_get_term_otype_for_iri(self):
        """Test get_term_otype returns 'u' for IRI terms."""
        from trustgraph.storage.triples.cassandra.write import get_term_otype
        from trustgraph.schema import Term, IRI
        assert get_term_otype(Term(type=IRI, iri='http://example.org/Alice')) == 'u'

    def test_get_term_otype_for_literal(self):
        """Test get_term_otype returns 'l' for LITERAL terms."""
        from trustgraph.storage.triples.cassandra.write import get_term_otype
        from trustgraph.schema import Term, LITERAL
        assert get_term_otype(Term(type=LITERAL, value='Alice Smith')) == 'l'

    def test_get_term_otype_for_blank(self):
        """Test get_term_otype returns 'u' for BLANK terms."""
        from trustgraph.storage.triples.cassandra.write import get_term_otype
        from trustgraph.schema import Term, BLANK
        # Blank nodes are stored with the same object type code as URIs.
        assert get_term_otype(Term(type=BLANK, id='_:b1')) == 'u'

    def test_get_term_otype_for_triple(self):
        """Test get_term_otype returns 't' for TRIPLE terms."""
        from trustgraph.storage.triples.cassandra.write import get_term_otype
        from trustgraph.schema import Term, TRIPLE
        assert get_term_otype(Term(type=TRIPLE)) == 't'

    def test_get_term_otype_for_none(self):
        """Test get_term_otype returns 'u' for None."""
        from trustgraph.storage.triples.cassandra.write import get_term_otype
        assert get_term_otype(None) == 'u'

    def test_get_term_dtype_for_literal(self):
        """Test get_term_dtype extracts datatype from LITERAL."""
        from trustgraph.storage.triples.cassandra.write import get_term_dtype
        from trustgraph.schema import Term, LITERAL
        assert get_term_dtype(Term(type=LITERAL, value='42', datatype='xsd:integer')) == 'xsd:integer'

    def test_get_term_dtype_for_non_literal(self):
        """Test get_term_dtype returns empty string for non-LITERAL."""
        from trustgraph.storage.triples.cassandra.write import get_term_dtype
        from trustgraph.schema import Term, IRI
        assert get_term_dtype(Term(type=IRI, iri='http://example.org/Alice')) == ''

    def test_get_term_dtype_for_none(self):
        """Test get_term_dtype returns empty string for None."""
        from trustgraph.storage.triples.cassandra.write import get_term_dtype
        assert get_term_dtype(None) == ''

    def test_get_term_lang_for_literal(self):
        """Test get_term_lang extracts language from LITERAL."""
        from trustgraph.storage.triples.cassandra.write import get_term_lang
        from trustgraph.schema import Term, LITERAL
        assert get_term_lang(Term(type=LITERAL, value='Alice Smith', language='en')) == 'en'

    def test_get_term_lang_for_non_literal(self):
        """Test get_term_lang returns empty string for non-LITERAL."""
        from trustgraph.storage.triples.cassandra.write import get_term_lang
        from trustgraph.schema import Term, IRI
        assert get_term_lang(Term(type=IRI, iri='http://example.org/Alice')) == ''
class TestServiceHelperFunctions:
    """Test cases for helper functions in service.py."""

    def test_create_term_with_uri_otype(self):
        """Test create_term creates IRI Term for otype='u'."""
        from trustgraph.query.triples.cassandra.service import create_term
        from trustgraph.schema import IRI
        t = create_term('http://example.org/Alice', otype='u')
        assert (t.type, t.iri) == (IRI, 'http://example.org/Alice')

    def test_create_term_with_literal_otype(self):
        """Test create_term creates LITERAL Term for otype='l'."""
        from trustgraph.query.triples.cassandra.service import create_term
        from trustgraph.schema import LITERAL
        t = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en')
        # Value, datatype and language tag all round-trip onto the Term.
        assert t.type == LITERAL
        assert t.value == 'Alice Smith'
        assert t.datatype == 'xsd:string'
        assert t.language == 'en'

    def test_create_term_with_triple_otype(self):
        """Test create_term creates IRI Term for otype='t'."""
        from trustgraph.query.triples.cassandra.service import create_term
        from trustgraph.schema import IRI
        t = create_term('http://example.org/statement1', otype='t')
        assert (t.type, t.iri) == (IRI, 'http://example.org/statement1')

    def test_create_term_heuristic_fallback_uri(self):
        """Test create_term uses URL heuristic when otype not provided."""
        from trustgraph.query.triples.cassandra.service import create_term
        from trustgraph.schema import IRI
        # No otype: URL-shaped input should be classified as an IRI.
        t = create_term('http://example.org/Alice')
        assert (t.type, t.iri) == (IRI, 'http://example.org/Alice')

    def test_create_term_heuristic_fallback_literal(self):
        """Test create_term uses literal heuristic when otype not provided."""
        from trustgraph.query.triples.cassandra.service import create_term
        from trustgraph.schema import LITERAL
        # No otype: plain text should be classified as a literal.
        t = create_term('Alice Smith')
        assert (t.type, t.value) == (LITERAL, 'Alice Smith')

View file

@ -0,0 +1,380 @@
"""
Unit tests for trustgraph.embeddings.row_embeddings.embeddings
Tests the Stage 1 processor that computes embeddings for row index fields.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
    """Exercises the Stage 1 row-embeddings processor in isolation.

    Covers construction defaults, index-name selection, index-value and
    embedding-text assembly, schema-config loading, and the on_message
    drop/process paths.
    """

    @staticmethod
    def _make_processor(**overrides):
        """Build a Processor wired to a mock taskgroup.

        Defaults id to 'test-processor'; any keyword overrides (id,
        batch_size, config_type, ...) are passed straight through.
        """
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        kwargs = {'taskgroup': AsyncMock(), 'id': 'test-processor'}
        kwargs.update(overrides)
        return Processor(**kwargs)

    @staticmethod
    def _mock_metadata(user, collection):
        """Return mock message metadata for the given user/collection pair."""
        metadata = MagicMock()
        metadata.user = user
        metadata.collection = collection
        metadata.id = 'doc-123'
        return metadata

    async def test_processor_initialization(self):
        """A fresh processor has an empty schema map and the default batch size."""
        processor = self._make_processor(id='test-row-embeddings')
        assert hasattr(processor, 'schemas')
        assert processor.schemas == {}
        assert processor.batch_size == 10  # default

    async def test_processor_initialization_with_custom_batch_size(self):
        """batch_size supplied at construction time is honoured."""
        processor = self._make_processor(id='test-row-embeddings', batch_size=25)
        assert processor.batch_size == 25

    async def test_get_index_names_single_index(self):
        """Primary-key and indexed fields are selected; unindexed fields are not."""
        from trustgraph.schema import RowSchema, Field
        processor = self._make_processor()
        customer_schema = RowSchema(
            name='customers',
            description='Customer records',
            fields=[
                Field(name='id', type='text', primary=True),
                Field(name='name', type='text', indexed=True),
                Field(name='email', type='text', indexed=False),
            ],
        )
        names = processor.get_index_names(customer_schema)
        assert 'id' in names
        assert 'name' in names
        assert 'email' not in names

    async def test_get_index_names_no_indexes(self):
        """A schema with no primary or indexed fields yields no index names."""
        from trustgraph.schema import RowSchema, Field
        processor = self._make_processor()
        log_schema = RowSchema(
            name='logs',
            description='Log records',
            fields=[
                Field(name='timestamp', type='text'),
                Field(name='message', type='text'),
            ],
        )
        assert processor.get_index_names(log_schema) == []

    async def test_build_index_value_single_field(self):
        """A single-field index extracts exactly that field's value."""
        processor = self._make_processor()
        row = {
            'id': 'CUST001',
            'name': 'John Doe',
            'email': 'john@example.com',
        }
        assert processor.build_index_value(row, 'name') == ['John Doe']

    async def test_build_index_value_composite_index(self):
        """A comma-separated index spec yields one value per component field."""
        processor = self._make_processor()
        row = {
            'first_name': 'John',
            'last_name': 'Doe',
            'city': 'New York',
        }
        result = processor.build_index_value(row, 'first_name, last_name')
        assert result == ['John', 'Doe']

    async def test_build_index_value_missing_field(self):
        """A field absent from the row maps to the empty string."""
        processor = self._make_processor()
        row = {'name': 'John Doe'}
        assert processor.build_index_value(row, 'missing_field') == ['']

    async def test_build_text_for_embedding_single_value(self):
        """A single index value embeds as itself."""
        processor = self._make_processor()
        assert processor.build_text_for_embedding(['John Doe']) == 'John Doe'

    async def test_build_text_for_embedding_multiple_values(self):
        """Multiple index values are joined with spaces."""
        processor = self._make_processor()
        result = processor.build_text_for_embedding(['John', 'Doe', 'NYC'])
        assert result == 'John Doe NYC'

    async def test_on_schema_config_loads_schemas(self):
        """JSON schema definitions under the config type are parsed and stored."""
        import json
        processor = self._make_processor(config_type='schema')
        schema_def = {
            'name': 'customers',
            'description': 'Customer records',
            'fields': [
                {'name': 'id', 'type': 'text', 'primary_key': True},
                {'name': 'name', 'type': 'text', 'indexed': True},
                {'name': 'email', 'type': 'text'},
            ],
        }
        config_data = {'schema': {'customers': json.dumps(schema_def)}}
        await processor.on_schema_config(config_data, 1)
        assert 'customers' in processor.schemas
        assert processor.schemas['customers'].name == 'customers'
        assert len(processor.schemas['customers'].fields) == 3

    async def test_on_schema_config_handles_missing_type(self):
        """Config payloads lacking the schema type leave the schema map empty."""
        processor = self._make_processor(config_type='schema')
        await processor.on_schema_config({'other_type': {}}, 1)
        assert processor.schemas == {}

    async def test_on_message_drops_unknown_collection(self):
        """Messages for an unregistered collection are dropped without output."""
        from trustgraph.schema import ExtractedObject
        processor = self._make_processor()
        # No collections registered on the processor.
        obj = ExtractedObject(
            metadata=self._mock_metadata('unknown_user', 'unknown_collection'),
            schema_name='customers',
            values=[{'id': '123', 'name': 'Test'}],
        )
        incoming = MagicMock()
        incoming.value.return_value = obj
        flow = MagicMock()
        await processor.on_message(incoming, MagicMock(), flow)
        # The flow resolver must never be asked for an output producer.
        flow.assert_not_called()

    async def test_on_message_drops_unknown_schema(self):
        """Messages naming an unregistered schema are dropped without output."""
        from trustgraph.schema import ExtractedObject
        processor = self._make_processor()
        processor.known_collections[('test_user', 'test_collection')] = {}
        # Collection is known but no schemas are registered.
        obj = ExtractedObject(
            metadata=self._mock_metadata('test_user', 'test_collection'),
            schema_name='unknown_schema',
            values=[{'id': '123', 'name': 'Test'}],
        )
        incoming = MagicMock()
        incoming.value.return_value = obj
        flow = MagicMock()
        await processor.on_message(incoming, MagicMock(), flow)
        flow.assert_not_called()

    async def test_on_message_processes_embeddings(self):
        """A known collection + schema triggers one embed per index value and sends output."""
        from trustgraph.schema import ExtractedObject, RowSchema, Field
        processor = self._make_processor(config_type='schema')
        processor.known_collections[('test_user', 'test_collection')] = {}
        processor.schemas['customers'] = RowSchema(
            name='customers',
            description='Customer records',
            fields=[
                Field(name='id', type='text', primary=True),
                Field(name='name', type='text', indexed=True),
            ],
        )
        obj = ExtractedObject(
            metadata=self._mock_metadata('test_user', 'test_collection'),
            schema_name='customers',
            values=[
                {'id': 'CUST001', 'name': 'John Doe'},
                {'id': 'CUST002', 'name': 'Jane Smith'},
            ],
        )
        incoming = MagicMock()
        incoming.value.return_value = obj

        embeddings_client = AsyncMock()
        embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]]
        output_client = AsyncMock()

        def resolve(name):
            # Route the processor's flow lookups to the right mock client.
            if name == 'embeddings-request':
                return embeddings_client
            if name == 'output':
                return output_client
            return MagicMock()

        flow = MagicMock(side_effect=resolve)
        await processor.on_message(incoming, MagicMock(), flow)

        # Four index values total: CUST001, John Doe, CUST002, Jane Smith.
        assert embeddings_client.embed.call_count == 4
        output_client.send.assert_called()
if __name__ == '__main__':
    # Allow running this test module directly, outside a pytest invocation.
    pytest.main([__file__])

View file

@ -7,7 +7,7 @@ collecting labels and definitions for entity embedding and retrieval.
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.schema.core.primitives import Triple, Value
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
from trustgraph.schema.knowledge.graph import EntityContext
@ -25,9 +25,9 @@ class TestEntityContextBuilding:
"""Test that entity context is built from rdfs:label."""
triples = [
Triple(
s=Value(value="https://example.com/entity/cornish-pasty", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Cornish Pasty", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/cornish-pasty"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Cornish Pasty")
)
]
@ -35,16 +35,16 @@ class TestEntityContextBuilding:
assert len(contexts) == 1, "Should create one entity context"
assert isinstance(contexts[0], EntityContext)
assert contexts[0].entity.value == "https://example.com/entity/cornish-pasty"
assert contexts[0].entity.iri == "https://example.com/entity/cornish-pasty"
assert "Label: Cornish Pasty" in contexts[0].context
def test_builds_context_from_definition(self, processor):
"""Test that entity context includes definitions."""
triples = [
Triple(
s=Value(value="https://example.com/entity/pasty", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="A baked pastry filled with savory ingredients", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/pasty"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="A baked pastry filled with savory ingredients")
)
]
@ -57,14 +57,14 @@ class TestEntityContextBuilding:
"""Test that label and definition are combined in context."""
triples = [
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Pasty Recipe", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Pasty Recipe")
),
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="Traditional Cornish pastry recipe", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="Traditional Cornish pastry recipe")
)
]
@ -80,14 +80,14 @@ class TestEntityContextBuilding:
"""Test that only the first label is used in context."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="First Label", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="First Label")
),
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Second Label", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Second Label")
)
]
@ -101,14 +101,14 @@ class TestEntityContextBuilding:
"""Test that all definitions are included in context."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="First definition", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="First definition")
),
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="Second definition", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="Second definition")
)
]
@ -123,9 +123,9 @@ class TestEntityContextBuilding:
"""Test that schema.org description is treated as definition."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="https://schema.org/description", is_uri=True),
o=Value(value="A delicious food item", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="https://schema.org/description"),
o=Term(type=LITERAL, value="A delicious food item")
)
]
@ -138,26 +138,26 @@ class TestEntityContextBuilding:
"""Test that contexts are created for multiple entities."""
triples = [
Triple(
s=Value(value="https://example.com/entity/entity1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Entity One", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/entity1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Entity One")
),
Triple(
s=Value(value="https://example.com/entity/entity2", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Entity Two", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/entity2"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Entity Two")
),
Triple(
s=Value(value="https://example.com/entity/entity3", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Entity Three", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/entity3"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Entity Three")
)
]
contexts = processor.build_entity_contexts(triples)
assert len(contexts) == 3, "Should create context for each entity"
entity_uris = [ctx.entity.value for ctx in contexts]
entity_uris = [ctx.entity.iri for ctx in contexts]
assert "https://example.com/entity/entity1" in entity_uris
assert "https://example.com/entity/entity2" in entity_uris
assert "https://example.com/entity/entity3" in entity_uris
@ -166,9 +166,9 @@ class TestEntityContextBuilding:
"""Test that URI objects are ignored (only literal labels/definitions)."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="https://example.com/some/uri", is_uri=True) # URI, not literal
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=IRI, iri="https://example.com/some/uri") # URI, not literal
)
]
@ -181,14 +181,14 @@ class TestEntityContextBuilding:
"""Test that other predicates are ignored."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://example.com/Food", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri="http://example.com/Food")
),
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://example.com/produces", is_uri=True),
o=Value(value="https://example.com/entity/food2", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://example.com/produces"),
o=Term(type=IRI, iri="https://example.com/entity/food2")
)
]
@ -205,29 +205,29 @@ class TestEntityContextBuilding:
assert len(contexts) == 0, "Empty triple list should return empty contexts"
def test_entity_context_has_value_object(self, processor):
"""Test that EntityContext.entity is a Value object."""
def test_entity_context_has_term_object(self, processor):
"""Test that EntityContext.entity is a Term object."""
triples = [
Triple(
s=Value(value="https://example.com/entity/test", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Test Entity", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/test"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Test Entity")
)
]
contexts = processor.build_entity_contexts(triples)
assert len(contexts) == 1
assert isinstance(contexts[0].entity, Value), "Entity should be Value object"
assert contexts[0].entity.is_uri, "Entity should be marked as URI"
assert isinstance(contexts[0].entity, Term), "Entity should be Term object"
assert contexts[0].entity.type == IRI, "Entity should be IRI type"
def test_entity_context_text_is_string(self, processor):
"""Test that EntityContext.context is a string."""
triples = [
Triple(
s=Value(value="https://example.com/entity/test", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Test Entity", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/test"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Test Entity")
)
]
@ -241,22 +241,22 @@ class TestEntityContextBuilding:
triples = [
# Entity with label - should create context
Triple(
s=Value(value="https://example.com/entity/entity1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Entity One", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/entity1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Entity One")
),
# Entity with only rdf:type - should NOT create context
Triple(
s=Value(value="https://example.com/entity/entity2", is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://example.com/Food", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/entity2"),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri="http://example.com/Food")
)
]
contexts = processor.build_entity_contexts(triples)
assert len(contexts) == 1, "Should only create context for entity with label/definition"
assert contexts[0].entity.value == "https://example.com/entity/entity1"
assert contexts[0].entity.iri == "https://example.com/entity/entity1"
class TestEntityContextEdgeCases:
@ -266,9 +266,9 @@ class TestEntityContextEdgeCases:
"""Test handling of unicode characters in labels."""
triples = [
Triple(
s=Value(value="https://example.com/entity/café", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Café Spécial", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/café"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Café Spécial")
)
]
@ -282,9 +282,9 @@ class TestEntityContextEdgeCases:
long_def = "This is a very long definition " * 50
triples = [
Triple(
s=Value(value="https://example.com/entity/test", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value=long_def, is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/test"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value=long_def)
)
]
@ -297,9 +297,9 @@ class TestEntityContextEdgeCases:
"""Test handling of special characters in context text."""
triples = [
Triple(
s=Value(value="https://example.com/entity/test", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Test & Entity <with> \"quotes\"", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/test"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Test & Entity <with> \"quotes\"")
)
]
@ -313,27 +313,27 @@ class TestEntityContextEdgeCases:
triples = [
# Label - relevant
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Cornish Pasty Recipe", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Cornish Pasty Recipe")
),
# Type - irrelevant
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://example.com/Recipe", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri="http://example.com/Recipe")
),
# Property - irrelevant
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://example.com/produces", is_uri=True),
o=Value(value="https://example.com/entity/pasty", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://example.com/produces"),
o=Term(type=IRI, iri="https://example.com/entity/pasty")
),
# Definition - relevant
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="Traditional British pastry recipe", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="Traditional British pastry recipe")
)
]

View file

@ -9,7 +9,7 @@ the knowledge graph.
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
from trustgraph.schema.core.primitives import Triple, Value
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
@pytest.fixture
@ -92,12 +92,12 @@ class TestOntologyTripleGeneration:
# Find type triples for Recipe class
recipe_type_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/Recipe"
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
]
assert len(recipe_type_triples) == 1, "Should generate exactly one type triple per class"
assert recipe_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#Class", \
assert recipe_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#Class", \
"Class type should be owl:Class"
def test_generates_class_labels(self, extractor, sample_ontology_subset):
@ -107,14 +107,14 @@ class TestOntologyTripleGeneration:
# Find label triples for Recipe class
recipe_label_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/Recipe"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
]
assert len(recipe_label_triples) == 1, "Should generate label triple for class"
assert recipe_label_triples[0].o.value == "Recipe", \
"Label should match class label from ontology"
assert not recipe_label_triples[0].o.is_uri, \
assert recipe_label_triples[0].o.type == LITERAL, \
"Label should be a literal, not URI"
def test_generates_class_comments(self, extractor, sample_ontology_subset):
@ -124,8 +124,8 @@ class TestOntologyTripleGeneration:
# Find comment triples for Recipe class
recipe_comment_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/Recipe"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#comment"
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#comment"
]
assert len(recipe_comment_triples) == 1, "Should generate comment triple for class"
@ -139,13 +139,13 @@ class TestOntologyTripleGeneration:
# Find type triples for ingredients property
ingredients_type_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/ingredients"
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
]
assert len(ingredients_type_triples) == 1, \
"Should generate exactly one type triple per object property"
assert ingredients_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#ObjectProperty", \
assert ingredients_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#ObjectProperty", \
"Object property type should be owl:ObjectProperty"
def test_generates_object_property_labels(self, extractor, sample_ontology_subset):
@ -155,8 +155,8 @@ class TestOntologyTripleGeneration:
# Find label triples for ingredients property
ingredients_label_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/ingredients"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
]
assert len(ingredients_label_triples) == 1, \
@ -171,15 +171,15 @@ class TestOntologyTripleGeneration:
# Find domain triples for ingredients property
ingredients_domain_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/ingredients"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#domain"
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#domain"
]
assert len(ingredients_domain_triples) == 1, \
"Should generate domain triple for object property"
assert ingredients_domain_triples[0].o.value == "http://purl.org/ontology/fo/Recipe", \
assert ingredients_domain_triples[0].o.iri == "http://purl.org/ontology/fo/Recipe", \
"Domain should be Recipe class URI"
assert ingredients_domain_triples[0].o.is_uri, \
assert ingredients_domain_triples[0].o.type == IRI, \
"Domain should be a URI reference"
def test_generates_object_property_range(self, extractor, sample_ontology_subset):
@ -189,13 +189,13 @@ class TestOntologyTripleGeneration:
# Find range triples for produces property
produces_range_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/produces"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range"
if t.s.iri == "http://purl.org/ontology/fo/produces"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range"
]
assert len(produces_range_triples) == 1, \
"Should generate range triple for object property"
assert produces_range_triples[0].o.value == "http://purl.org/ontology/fo/Food", \
assert produces_range_triples[0].o.iri == "http://purl.org/ontology/fo/Food", \
"Range should be Food class URI"
def test_generates_datatype_property_type_triples(self, extractor, sample_ontology_subset):
@ -205,13 +205,13 @@ class TestOntologyTripleGeneration:
# Find type triples for serves property
serves_type_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/serves"
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
if t.s.iri == "http://purl.org/ontology/fo/serves"
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
]
assert len(serves_type_triples) == 1, \
"Should generate exactly one type triple per datatype property"
assert serves_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#DatatypeProperty", \
assert serves_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#DatatypeProperty", \
"Datatype property type should be owl:DatatypeProperty"
def test_generates_datatype_property_range(self, extractor, sample_ontology_subset):
@ -221,13 +221,13 @@ class TestOntologyTripleGeneration:
# Find range triples for serves property
serves_range_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/serves"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range"
if t.s.iri == "http://purl.org/ontology/fo/serves"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range"
]
assert len(serves_range_triples) == 1, \
"Should generate range triple for datatype property"
assert serves_range_triples[0].o.value == "http://www.w3.org/2001/XMLSchema#string", \
assert serves_range_triples[0].o.iri == "http://www.w3.org/2001/XMLSchema#string", \
"Range should be XSD type URI (xsd:string expanded)"
def test_generates_triples_for_all_classes(self, extractor, sample_ontology_subset):
@ -236,9 +236,9 @@ class TestOntologyTripleGeneration:
# Count unique class subjects
class_subjects = set(
t.s.value for t in triples
if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
and t.o.value == "http://www.w3.org/2002/07/owl#Class"
t.s.iri for t in triples
if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
and t.o.iri == "http://www.w3.org/2002/07/owl#Class"
)
assert len(class_subjects) == 3, \
@ -250,9 +250,9 @@ class TestOntologyTripleGeneration:
# Count unique property subjects (object + datatype properties)
property_subjects = set(
t.s.value for t in triples
if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
and ("ObjectProperty" in t.o.value or "DatatypeProperty" in t.o.value)
t.s.iri for t in triples
if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
and ("ObjectProperty" in t.o.iri or "DatatypeProperty" in t.o.iri)
)
assert len(property_subjects) == 3, \
@ -276,7 +276,7 @@ class TestOntologyTripleGeneration:
# Should still generate proper RDF triples despite dict field names
label_triples = [
t for t in triples
if t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
if t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
]
assert len(label_triples) > 0, \
"Should generate rdfs:label triples from dict 'labels' field"

View file

@ -8,7 +8,7 @@ and extracts/validates triples from LLM responses.
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
from trustgraph.schema.core.primitives import Triple, Value
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
@pytest.fixture
@ -248,9 +248,9 @@ class TestTripleParsing:
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
assert len(validated) == 1, "Should parse one valid triple"
assert validated[0].s.value == "https://trustgraph.ai/food/cornish-pasty"
assert validated[0].p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe"
assert validated[0].s.iri == "https://trustgraph.ai/food/cornish-pasty"
assert validated[0].p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe"
def test_parse_multiple_triples(self, extractor, sample_ontology_subset):
"""Test parsing multiple triples."""
@ -307,11 +307,11 @@ class TestTripleParsing:
assert len(validated) == 1
# Subject should be expanded to entity URI
assert validated[0].s.value.startswith("https://trustgraph.ai/food/")
assert validated[0].s.iri.startswith("https://trustgraph.ai/food/")
# Predicate should be expanded to ontology URI
assert validated[0].p.value == "http://purl.org/ontology/fo/produces"
assert validated[0].p.iri == "http://purl.org/ontology/fo/produces"
# Object should be expanded to class URI
assert validated[0].o.value == "http://purl.org/ontology/fo/Food"
assert validated[0].o.iri == "http://purl.org/ontology/fo/Food"
def test_creates_proper_triple_objects(self, extractor, sample_ontology_subset):
"""Test that Triple objects are properly created."""
@ -324,12 +324,12 @@ class TestTripleParsing:
assert len(validated) == 1
triple = validated[0]
assert isinstance(triple, Triple), "Should create Triple objects"
assert isinstance(triple.s, Value), "Subject should be Value object"
assert isinstance(triple.p, Value), "Predicate should be Value object"
assert isinstance(triple.o, Value), "Object should be Value object"
assert triple.s.is_uri, "Subject should be marked as URI"
assert triple.p.is_uri, "Predicate should be marked as URI"
assert not triple.o.is_uri, "Object literal should not be marked as URI"
assert isinstance(triple.s, Term), "Subject should be Term object"
assert isinstance(triple.p, Term), "Predicate should be Term object"
assert isinstance(triple.o, Term), "Object should be Term object"
assert triple.s.type == IRI, "Subject should be IRI type"
assert triple.p.type == IRI, "Predicate should be IRI type"
assert triple.o.type == LITERAL, "Object literal should be LITERAL type"
class TestURIExpansionInExtraction:
@ -343,8 +343,8 @@ class TestURIExpansionInExtraction:
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe"
assert validated[0].o.is_uri, "Class reference should be URI"
assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe"
assert validated[0].o.type == IRI, "Class reference should be URI"
def test_expands_property_names(self, extractor, sample_ontology_subset):
"""Test that property names are expanded to full URIs."""
@ -354,7 +354,7 @@ class TestURIExpansionInExtraction:
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
assert validated[0].p.value == "http://purl.org/ontology/fo/produces"
assert validated[0].p.iri == "http://purl.org/ontology/fo/produces"
def test_expands_entity_instances(self, extractor, sample_ontology_subset):
"""Test that entity instances get constructed URIs."""
@ -364,8 +364,8 @@ class TestURIExpansionInExtraction:
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
assert validated[0].s.value.startswith("https://trustgraph.ai/food/")
assert "my-special-recipe" in validated[0].s.value
assert validated[0].s.iri.startswith("https://trustgraph.ai/food/")
assert "my-special-recipe" in validated[0].s.iri
class TestEdgeCases:

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock
from trustgraph.gateway.dispatch.serialize import to_value, to_subgraph, serialize_value
from trustgraph.schema import Value, Triple
from trustgraph.schema import Term, Triple, IRI, LITERAL
class TestDispatchSerialize:
@ -14,55 +14,55 @@ class TestDispatchSerialize:
def test_to_value_with_uri(self):
"""Test to_value function with URI"""
input_data = {"v": "http://example.com/resource", "e": True}
input_data = {"t": "i", "i": "http://example.com/resource"}
result = to_value(input_data)
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_to_value_with_literal(self):
"""Test to_value function with literal value"""
input_data = {"v": "literal string", "e": False}
input_data = {"t": "l", "v": "literal string"}
result = to_value(input_data)
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_to_subgraph_with_multiple_triples(self):
"""Test to_subgraph function with multiple triples"""
input_data = [
{
"s": {"v": "subject1", "e": True},
"p": {"v": "predicate1", "e": True},
"o": {"v": "object1", "e": False}
"s": {"t": "i", "i": "subject1"},
"p": {"t": "i", "i": "predicate1"},
"o": {"t": "l", "v": "object1"}
},
{
"s": {"v": "subject2", "e": False},
"p": {"v": "predicate2", "e": True},
"o": {"v": "object2", "e": True}
"s": {"t": "l", "v": "subject2"},
"p": {"t": "i", "i": "predicate2"},
"o": {"t": "i", "i": "object2"}
}
]
result = to_subgraph(input_data)
assert len(result) == 2
assert all(isinstance(triple, Triple) for triple in result)
# Check first triple
assert result[0].s.value == "subject1"
assert result[0].s.is_uri is True
assert result[0].p.value == "predicate1"
assert result[0].p.is_uri is True
assert result[0].s.iri == "subject1"
assert result[0].s.type == IRI
assert result[0].p.iri == "predicate1"
assert result[0].p.type == IRI
assert result[0].o.value == "object1"
assert result[0].o.is_uri is False
assert result[0].o.type == LITERAL
# Check second triple
assert result[1].s.value == "subject2"
assert result[1].s.is_uri is False
assert result[1].s.type == LITERAL
def test_to_subgraph_with_empty_list(self):
"""Test to_subgraph function with empty input"""
@ -74,16 +74,16 @@ class TestDispatchSerialize:
def test_serialize_value_with_uri(self):
"""Test serialize_value function with URI value"""
value = Value(value="http://example.com/test", is_uri=True)
result = serialize_value(value)
assert result == {"v": "http://example.com/test", "e": True}
term = Term(type=IRI, iri="http://example.com/test")
result = serialize_value(term)
assert result == {"t": "i", "i": "http://example.com/test"}
def test_serialize_value_with_literal(self):
"""Test serialize_value function with literal value"""
value = Value(value="test literal", is_uri=False)
result = serialize_value(value)
assert result == {"v": "test literal", "e": False}
term = Term(type=LITERAL, value="test literal")
result = serialize_value(term)
assert result == {"t": "l", "v": "test literal"}

View file

@ -1,7 +1,7 @@
"""
Unit tests for objects import dispatcher.
Unit tests for rows import dispatcher.
Tests the business logic of objects import dispatcher
Tests the business logic of rows import dispatcher
while mocking the Publisher and websocket components.
"""
@ -11,7 +11,7 @@ import asyncio
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from aiohttp import web
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
from trustgraph.gateway.dispatch.rows_import import RowsImport
from trustgraph.schema import Metadata, ExtractedObject
@ -92,16 +92,16 @@ def minimal_objects_message():
}
class TestObjectsImportInitialization:
"""Test ObjectsImport initialization."""
class TestRowsImportInitialization:
"""Test RowsImport initialization."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that ObjectsImport creates Publisher with correct parameters."""
"""Test that RowsImport creates Publisher with correct parameters."""
mock_publisher_instance = Mock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -116,28 +116,28 @@ class TestObjectsImportInitialization:
)
# Verify instance variables are set correctly
assert objects_import.ws == mock_websocket
assert objects_import.running == mock_running
assert objects_import.publisher == mock_publisher_instance
assert rows_import.ws == mock_websocket
assert rows_import.running == mock_running
assert rows_import.publisher == mock_publisher_instance
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that ObjectsImport stores all required references."""
objects_import = ObjectsImport(
"""Test that RowsImport stores all required references."""
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="objects-queue"
)
assert objects_import.ws is mock_websocket
assert objects_import.running is mock_running
assert rows_import.ws is mock_websocket
assert rows_import.running is mock_running
class TestObjectsImportLifecycle:
"""Test ObjectsImport lifecycle methods."""
class TestRowsImportLifecycle:
"""Test RowsImport lifecycle methods."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that start() calls publisher.start()."""
@ -145,18 +145,18 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.start = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.start()
await rows_import.start()
mock_publisher_instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that destroy() properly stops publisher and closes websocket."""
@ -164,21 +164,21 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.destroy()
await rows_import.destroy()
# Verify sequence of operations
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
"""Test that destroy() handles None websocket gracefully."""
@ -186,7 +186,7 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=None, # None websocket
running=mock_running,
backend=mock_backend,
@ -194,16 +194,16 @@ class TestObjectsImportLifecycle:
)
# Should not raise exception
await objects_import.destroy()
await rows_import.destroy()
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
class TestObjectsImportMessageProcessing:
"""Test ObjectsImport message processing."""
class TestRowsImportMessageProcessing:
"""Test RowsImport message processing."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() processes complete message correctly."""
@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = sample_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing:
assert sent_object.metadata.collection == "testcollection"
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
"""Test that receive() handles message with minimal required fields."""
@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = minimal_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing:
assert sent_object.source_span == "" # Default value
assert len(sent_object.metadata.metadata) == 0 # Default empty list
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() uses appropriate default values for optional fields."""
@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = message_data
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Get the sent object and verify defaults
sent_object = mock_publisher_instance.send.call_args[0][1]
@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing:
assert sent_object.source_span == ""
class TestObjectsImportRunMethod:
"""Test ObjectsImport run method."""
class TestRowsImportRunMethod:
"""Test RowsImport run method."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that run() loops while running.get() returns True."""
@ -331,14 +331,14 @@ class TestObjectsImportRunMethod:
# Set up running state to return True twice, then False
mock_running.get.side_effect = [True, True, False]
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.run()
await rows_import.run()
# Verify sleep was called twice (for the two True iterations)
assert mock_sleep.call_count == 2
@ -348,10 +348,10 @@ class TestObjectsImportRunMethod:
mock_websocket.close.assert_called_once()
# Verify websocket was set to None
assert objects_import.ws is None
assert rows_import.ws is None
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
"""Test that run() handles None websocket gracefully."""
@ -360,7 +360,7 @@ class TestObjectsImportRunMethod:
mock_running.get.return_value = False # Exit immediately
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=None, # None websocket
running=mock_running,
backend=mock_backend,
@ -368,14 +368,14 @@ class TestObjectsImportRunMethod:
)
# Should not raise exception
await objects_import.run()
await rows_import.run()
# Verify websocket remains None
assert objects_import.ws is None
assert rows_import.ws is None
class TestObjectsImportBatchProcessing:
"""Test ObjectsImport batch processing functionality."""
class TestRowsImportBatchProcessing:
"""Test RowsImport batch processing functionality."""
@pytest.fixture
def batch_objects_message(self):
@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing:
"source_span": "Multiple people found in document"
}
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
"""Test that receive() processes batch message correctly."""
@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing:
mock_msg = Mock()
mock_msg.json.return_value = batch_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing:
assert sent_object.confidence == 0.85
assert sent_object.source_span == "Multiple people found in document"
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() handles empty batch correctly."""
@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing:
mock_msg = Mock()
mock_msg.json.return_value = empty_batch_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Should still send the message
mock_publisher_instance.send.assert_called_once()
@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing:
assert len(sent_object.values) == 0
class TestObjectsImportErrorHandling:
"""Test error handling in ObjectsImport."""
class TestRowsImportErrorHandling:
"""Test error handling in RowsImport."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() propagates publisher send errors."""
@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling:
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling:
mock_msg.json.return_value = sample_objects_message
with pytest.raises(Exception, match="Publisher error"):
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() handles malformed JSON appropriately."""
mock_publisher_class.return_value = Mock()
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling:
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
with pytest.raises(json.JSONDecodeError):
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)

View file

@ -6,11 +6,21 @@ import pytest
from unittest.mock import Mock, AsyncMock
# Mock schema classes for testing
class Value:
def __init__(self, value, is_uri, type):
self.value = value
self.is_uri = is_uri
# Term type constants
IRI = "i"
LITERAL = "l"
BLANK = "b"
TRIPLE = "t"
class Term:
def __init__(self, type, iri=None, value=None, id=None, datatype=None, language=None, triple=None):
self.type = type
self.iri = iri
self.value = value
self.id = id
self.datatype = datatype
self.language = language
self.triple = triple
class Triple:
def __init__(self, s, p, o):
@ -66,32 +76,30 @@ def sample_relationships():
@pytest.fixture
def sample_value_uri():
"""Sample URI Value object"""
return Value(
value="http://example.com/person/john-smith",
is_uri=True,
type=""
def sample_term_uri():
"""Sample URI Term object"""
return Term(
type=IRI,
iri="http://example.com/person/john-smith"
)
@pytest.fixture
def sample_value_literal():
"""Sample literal Value object"""
return Value(
value="John Smith",
is_uri=False,
type="string"
def sample_term_literal():
"""Sample literal Term object"""
return Term(
type=LITERAL,
value="John Smith"
)
@pytest.fixture
def sample_triple(sample_value_uri, sample_value_literal):
def sample_triple(sample_term_uri, sample_term_literal):
"""Sample Triple object"""
return Triple(
s=sample_value_uri,
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=sample_value_literal
s=sample_term_uri,
p=Term(type=IRI, iri="http://schema.org/name"),
o=sample_term_literal
)

View file

@ -11,7 +11,7 @@ import json
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from trustgraph.template.prompt_manager import PromptManager
@ -33,7 +33,7 @@ class TestAgentKgExtractor:
# Set up the methods we want to test
extractor.to_uri = real_extractor.to_uri
extractor.parse_json = real_extractor.parse_json
extractor.parse_jsonl = real_extractor.parse_jsonl
extractor.process_extraction_data = real_extractor.process_extraction_data
extractor.emit_triples = real_extractor.emit_triples
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
@ -53,48 +53,49 @@ class TestAgentKgExtractor:
id="doc123",
metadata=[
Triple(
s=Value(value="doc123", is_uri=True),
p=Value(value="http://example.org/type", is_uri=True),
o=Value(value="document", is_uri=False)
s=Term(type=IRI, iri="doc123"),
p=Term(type=IRI, iri="http://example.org/type"),
o=Term(type=LITERAL, value="document")
)
]
)
@pytest.fixture
def sample_extraction_data(self):
"""Sample extraction data in expected format"""
return {
"definitions": [
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
],
"relationships": [
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "used_in",
"object": "Machine Learning",
"object-entity": True
},
{
"subject": "Deep Learning",
"predicate": "accuracy",
"object": "95%",
"object-entity": False
}
]
}
"""Sample extraction data in JSONL format (list with type discriminators)"""
return [
{
"type": "definition",
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"type": "definition",
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
},
{
"type": "relationship",
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"type": "relationship",
"subject": "Neural Networks",
"predicate": "used_in",
"object": "Machine Learning",
"object-entity": True
},
{
"type": "relationship",
"subject": "Deep Learning",
"predicate": "accuracy",
"object": "95%",
"object-entity": False
}
]
def test_to_uri_conversion(self, agent_extractor):
"""Test URI conversion for entities"""
@ -113,148 +114,147 @@ class TestAgentKgExtractor:
expected = f"{TRUSTGRAPH_ENTITIES}"
assert uri == expected
def test_parse_json_with_code_blocks(self, agent_extractor):
"""Test JSON parsing from code blocks"""
# Test JSON in code blocks
def test_parse_jsonl_with_code_blocks(self, agent_extractor):
"""Test JSONL parsing from code blocks"""
# Test JSONL in code blocks - note: JSON uses lowercase true/false
response = '''```json
{
"definitions": [{"entity": "AI", "definition": "Artificial Intelligence"}],
"relationships": []
}
```'''
result = agent_extractor.parse_json(response)
assert result["definitions"][0]["entity"] == "AI"
assert result["definitions"][0]["definition"] == "Artificial Intelligence"
assert result["relationships"] == []
{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"}
{"type": "relationship", "subject": "AI", "predicate": "is", "object": "technology", "object-entity": false}
```'''
def test_parse_json_without_code_blocks(self, agent_extractor):
"""Test JSON parsing without code blocks"""
response = '''{"definitions": [{"entity": "ML", "definition": "Machine Learning"}], "relationships": []}'''
result = agent_extractor.parse_json(response)
assert result["definitions"][0]["entity"] == "ML"
assert result["definitions"][0]["definition"] == "Machine Learning"
result = agent_extractor.parse_jsonl(response)
def test_parse_json_invalid_format(self, agent_extractor):
"""Test JSON parsing with invalid format"""
invalid_response = "This is not JSON at all"
with pytest.raises(json.JSONDecodeError):
agent_extractor.parse_json(invalid_response)
assert len(result) == 2
assert result[0]["entity"] == "AI"
assert result[0]["definition"] == "Artificial Intelligence"
assert result[1]["type"] == "relationship"
def test_parse_json_malformed_code_blocks(self, agent_extractor):
"""Test JSON parsing with malformed code blocks"""
# Missing closing backticks
response = '''```json
{"definitions": [], "relationships": []}
'''
# Should still parse the JSON content
with pytest.raises(json.JSONDecodeError):
agent_extractor.parse_json(response)
def test_parse_jsonl_without_code_blocks(self, agent_extractor):
"""Test JSONL parsing without code blocks"""
response = '''{"type": "definition", "entity": "ML", "definition": "Machine Learning"}
{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"}'''
result = agent_extractor.parse_jsonl(response)
assert len(result) == 2
assert result[0]["entity"] == "ML"
assert result[1]["entity"] == "AI"
def test_parse_jsonl_invalid_lines_skipped(self, agent_extractor):
"""Test JSONL parsing skips invalid lines gracefully"""
response = '''{"type": "definition", "entity": "Valid", "definition": "Valid def"}
This is not JSON at all
{"type": "definition", "entity": "Also Valid", "definition": "Another def"}'''
result = agent_extractor.parse_jsonl(response)
# Should get 2 valid objects, skipping the invalid line
assert len(result) == 2
assert result[0]["entity"] == "Valid"
assert result[1]["entity"] == "Also Valid"
def test_parse_jsonl_truncation_resilience(self, agent_extractor):
"""Test JSONL parsing handles truncated responses"""
# Simulates output cut off mid-line
response = '''{"type": "definition", "entity": "Complete", "definition": "Full def"}
{"type": "definition", "entity": "Trunca'''
result = agent_extractor.parse_jsonl(response)
# Should get 1 valid object, the truncated line is skipped
assert len(result) == 1
assert result[0]["entity"] == "Complete"
def test_process_extraction_data_definitions(self, agent_extractor, sample_metadata):
"""Test processing of definition data"""
data = {
"definitions": [
{
"entity": "Machine Learning",
"definition": "A subset of AI that enables learning from data."
}
],
"relationships": []
}
data = [
{
"type": "definition",
"entity": "Machine Learning",
"definition": "A subset of AI that enables learning from data."
}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Check entity label triple
label_triple = next((t for t in triples if t.p.value == RDF_LABEL and t.o.value == "Machine Learning"), None)
label_triple = next((t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "Machine Learning"), None)
assert label_triple is not None
assert label_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert label_triple.s.is_uri == True
assert label_triple.o.is_uri == False
assert label_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert label_triple.s.type == IRI
assert label_triple.o.type == LITERAL
# Check definition triple
def_triple = next((t for t in triples if t.p.value == DEFINITION), None)
def_triple = next((t for t in triples if t.p.iri == DEFINITION), None)
assert def_triple is not None
assert def_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert def_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert def_triple.o.value == "A subset of AI that enables learning from data."
# Check subject-of triple
subject_of_triple = next((t for t in triples if t.p.value == SUBJECT_OF), None)
subject_of_triple = next((t for t in triples if t.p.iri == SUBJECT_OF), None)
assert subject_of_triple is not None
assert subject_of_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert subject_of_triple.o.value == "doc123"
assert subject_of_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert subject_of_triple.o.iri == "doc123"
# Check entity context
assert len(entity_contexts) == 1
assert entity_contexts[0].entity.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert entity_contexts[0].entity.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert entity_contexts[0].context == "A subset of AI that enables learning from data."
def test_process_extraction_data_relationships(self, agent_extractor, sample_metadata):
"""Test processing of relationship data"""
data = {
"definitions": [],
"relationships": [
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
}
]
}
data = [
{
"type": "relationship",
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Check that subject, predicate, and object labels are created
subject_uri = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
predicate_uri = f"{TRUSTGRAPH_ENTITIES}is_subset_of"
# Find label triples
subject_label = next((t for t in triples if t.s.value == subject_uri and t.p.value == RDF_LABEL), None)
subject_label = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == RDF_LABEL), None)
assert subject_label is not None
assert subject_label.o.value == "Machine Learning"
predicate_label = next((t for t in triples if t.s.value == predicate_uri and t.p.value == RDF_LABEL), None)
predicate_label = next((t for t in triples if t.s.iri == predicate_uri and t.p.iri == RDF_LABEL), None)
assert predicate_label is not None
assert predicate_label.o.value == "is_subset_of"
# Check main relationship triple
# NOTE: Current implementation has bugs:
# 1. Uses data.get("object-entity") instead of rel.get("object-entity")
# 2. Sets object_value to predicate_uri instead of actual object URI
# This test documents the current buggy behavior
rel_triple = next((t for t in triples if t.s.value == subject_uri and t.p.value == predicate_uri), None)
# Check main relationship triple
object_uri = f"{TRUSTGRAPH_ENTITIES}Artificial%20Intelligence"
rel_triple = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == predicate_uri), None)
assert rel_triple is not None
# Due to bug, object value is set to predicate_uri
assert rel_triple.o.value == predicate_uri
assert rel_triple.o.iri == object_uri
assert rel_triple.o.type == IRI
# Check subject-of relationships
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF and t.o.value == "doc123"]
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF and t.o.iri == "doc123"]
assert len(subject_of_triples) >= 2 # At least subject and predicate should have subject-of relations
def test_process_extraction_data_literal_object(self, agent_extractor, sample_metadata):
"""Test processing of relationships with literal objects"""
data = {
"definitions": [],
"relationships": [
{
"subject": "Deep Learning",
"predicate": "accuracy",
"object": "95%",
"object-entity": False
}
]
}
data = [
{
"type": "relationship",
"subject": "Deep Learning",
"predicate": "accuracy",
"object": "95%",
"object-entity": False
}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Check that object labels are not created for literal objects
object_labels = [t for t in triples if t.p.value == RDF_LABEL and t.o.value == "95%"]
object_labels = [t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "95%"]
# Based on the code logic, it should not create object labels for non-entity objects
# But there might be a bug in the original implementation
@ -263,75 +263,62 @@ class TestAgentKgExtractor:
triples, entity_contexts = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata)
# Check that we have both definition and relationship triples
definition_triples = [t for t in triples if t.p.value == DEFINITION]
definition_triples = [t for t in triples if t.p.iri == DEFINITION]
assert len(definition_triples) == 2 # Two definitions
# Check entity contexts are created for definitions
assert len(entity_contexts) == 2
entity_uris = [ec.entity.value for ec in entity_contexts]
entity_uris = [ec.entity.iri for ec in entity_contexts]
assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris
def test_process_extraction_data_no_metadata_id(self, agent_extractor):
"""Test processing when metadata has no ID"""
metadata = Metadata(id=None, metadata=[])
data = {
"definitions": [
{"entity": "Test Entity", "definition": "Test definition"}
],
"relationships": []
}
data = [
{"type": "definition", "entity": "Test Entity", "definition": "Test definition"}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, metadata)
# Should not create subject-of relationships when no metadata ID
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF]
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
assert len(subject_of_triples) == 0
# Should still create entity contexts
assert len(entity_contexts) == 1
def test_process_extraction_data_empty_data(self, agent_extractor, sample_metadata):
"""Test processing of empty extraction data"""
data = {"definitions": [], "relationships": []}
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Should only have metadata triples
assert len(entity_contexts) == 0
# Triples should only contain metadata triples if any
data = []
def test_process_extraction_data_missing_keys(self, agent_extractor, sample_metadata):
"""Test processing data with missing keys"""
# Test missing definitions key
data = {"relationships": []}
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Should have no entity contexts
assert len(entity_contexts) == 0
# Test missing relationships key
data = {"definitions": []}
# Triples should be empty
assert len(triples) == 0
def test_process_extraction_data_unknown_types_ignored(self, agent_extractor, sample_metadata):
"""Test processing data with unknown type values"""
data = [
{"type": "definition", "entity": "Valid", "definition": "Valid def"},
{"type": "unknown_type", "foo": "bar"}, # Unknown type - should be ignored
{"type": "relationship", "subject": "A", "predicate": "rel", "object": "B", "object-entity": True}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
assert len(entity_contexts) == 0
# Test completely missing keys
data = {}
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
assert len(entity_contexts) == 0
# Should process valid items and ignore unknown types
assert len(entity_contexts) == 1 # Only the definition creates entity context
def test_process_extraction_data_malformed_entries(self, agent_extractor, sample_metadata):
"""Test processing data with malformed entries"""
# Test definition missing required fields
data = {
"definitions": [
{"entity": "Test"}, # Missing definition
{"definition": "Test def"} # Missing entity
],
"relationships": [
{"subject": "A", "predicate": "rel"}, # Missing object
{"subject": "B", "object": "C"} # Missing predicate
]
}
# Test items missing required fields - should raise KeyError
data = [
{"type": "definition", "entity": "Test"}, # Missing definition
]
# Should handle gracefully or raise appropriate errors
with pytest.raises(KeyError):
agent_extractor.process_extraction_data(data, sample_metadata)
@ -340,17 +327,17 @@ class TestAgentKgExtractor:
async def test_emit_triples(self, agent_extractor, sample_metadata):
"""Test emitting triples to publisher"""
mock_publisher = AsyncMock()
test_triples = [
Triple(
s=Value(value="test:subject", is_uri=True),
p=Value(value="test:predicate", is_uri=True),
o=Value(value="test object", is_uri=False)
s=Term(type=IRI, iri="test:subject"),
p=Term(type=IRI, iri="test:predicate"),
o=Term(type=LITERAL, value="test object")
)
]
await agent_extractor.emit_triples(mock_publisher, sample_metadata, test_triples)
mock_publisher.send.assert_called_once()
sent_triples = mock_publisher.send.call_args[0][0]
assert isinstance(sent_triples, Triples)
@ -361,22 +348,22 @@ class TestAgentKgExtractor:
# Note: metadata.metadata is now empty array in the new implementation
assert sent_triples.metadata.metadata == []
assert len(sent_triples.triples) == 1
assert sent_triples.triples[0].s.value == "test:subject"
assert sent_triples.triples[0].s.iri == "test:subject"
@pytest.mark.asyncio
async def test_emit_entity_contexts(self, agent_extractor, sample_metadata):
"""Test emitting entity contexts to publisher"""
mock_publisher = AsyncMock()
test_contexts = [
EntityContext(
entity=Value(value="test:entity", is_uri=True),
entity=Term(type=IRI, iri="test:entity"),
context="Test context"
)
]
await agent_extractor.emit_entity_contexts(mock_publisher, sample_metadata, test_contexts)
mock_publisher.send.assert_called_once()
sent_contexts = mock_publisher.send.call_args[0][0]
assert isinstance(sent_contexts, EntityContexts)
@ -387,7 +374,7 @@ class TestAgentKgExtractor:
# Note: metadata.metadata is now empty array in the new implementation
assert sent_contexts.metadata.metadata == []
assert len(sent_contexts.entities) == 1
assert sent_contexts.entities[0].entity.value == "test:entity"
assert sent_contexts.entities[0].entity.iri == "test:entity"
def test_agent_extractor_initialization_params(self):
"""Test agent extractor parameter validation"""

View file

@ -11,7 +11,7 @@ import urllib.parse
from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
@ -32,11 +32,11 @@ class TestAgentKgExtractionEdgeCases:
# Set up the methods we want to test
extractor.to_uri = real_extractor.to_uri
extractor.parse_json = real_extractor.parse_json
extractor.parse_jsonl = real_extractor.parse_jsonl
extractor.process_extraction_data = real_extractor.process_extraction_data
extractor.emit_triples = real_extractor.emit_triples
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
return extractor
def test_to_uri_special_characters(self, agent_extractor):
@ -85,146 +85,116 @@ class TestAgentKgExtractionEdgeCases:
# Verify the URI is properly encoded
assert unicode_text not in uri # Original unicode should be encoded
def test_parse_json_whitespace_variations(self, agent_extractor):
"""Test JSON parsing with various whitespace patterns"""
# Test JSON with different whitespace patterns
def test_parse_jsonl_whitespace_variations(self, agent_extractor):
"""Test JSONL parsing with various whitespace patterns"""
# Test JSONL with different whitespace patterns
test_cases = [
# Extra whitespace around code blocks
" ```json\n{\"test\": true}\n``` ",
# Tabs and mixed whitespace
"\t\t```json\n\t{\"test\": true}\n\t```\t",
# Multiple newlines
"\n\n\n```json\n\n{\"test\": true}\n\n```\n\n",
# JSON without code blocks but with whitespace
" {\"test\": true} ",
# Mixed line endings
"```json\r\n{\"test\": true}\r\n```",
' ```json\n{"type": "definition", "entity": "test", "definition": "def"}\n``` ',
# Multiple newlines between lines
'{"type": "definition", "entity": "A", "definition": "def A"}\n\n\n{"type": "definition", "entity": "B", "definition": "def B"}',
# JSONL without code blocks but with whitespace
' {"type": "definition", "entity": "test", "definition": "def"} ',
]
for response in test_cases:
result = agent_extractor.parse_json(response)
assert result == {"test": True}
def test_parse_json_code_block_variations(self, agent_extractor):
"""Test JSON parsing with different code block formats"""
for response in test_cases:
result = agent_extractor.parse_jsonl(response)
assert len(result) >= 1
assert result[0].get("type") == "definition"
def test_parse_jsonl_code_block_variations(self, agent_extractor):
"""Test JSONL parsing with different code block formats"""
test_cases = [
# Standard json code block
"```json\n{\"valid\": true}\n```",
'```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
# jsonl code block
'```jsonl\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
# Code block without language
"```\n{\"valid\": true}\n```",
# Uppercase JSON
"```JSON\n{\"valid\": true}\n```",
# Mixed case
"```Json\n{\"valid\": true}\n```",
# Multiple code blocks (should take first one)
"```json\n{\"first\": true}\n```\n```json\n{\"second\": true}\n```",
# Code block with extra content
"Here's the result:\n```json\n{\"valid\": true}\n```\nDone!",
'```\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
# Code block with extra content before/after
'Here\'s the result:\n```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```\nDone!',
]
for i, response in enumerate(test_cases):
try:
result = agent_extractor.parse_json(response)
assert result.get("valid") == True or result.get("first") == True
except json.JSONDecodeError:
# Some cases may fail due to regex extraction issues
# This documents current behavior - the regex may not match all cases
print(f"Case {i} failed JSON parsing: {response[:50]}...")
pass
result = agent_extractor.parse_jsonl(response)
assert len(result) >= 1, f"Case {i} failed"
assert result[0].get("entity") == "A"
def test_parse_json_malformed_code_blocks(self, agent_extractor):
"""Test JSON parsing with malformed code block formats"""
# These should still work by falling back to treating entire text as JSON
test_cases = [
# Unclosed code block
"```json\n{\"test\": true}",
# No opening backticks
"{\"test\": true}\n```",
# Wrong number of backticks
"`json\n{\"test\": true}\n`",
# Nested backticks (should handle gracefully)
"```json\n{\"code\": \"```\", \"test\": true}\n```",
]
for response in test_cases:
try:
result = agent_extractor.parse_json(response)
assert "test" in result # Should successfully parse
except json.JSONDecodeError:
# This is also acceptable for malformed cases
pass
def test_parse_jsonl_truncation_resilience(self, agent_extractor):
"""Test JSONL parsing with truncated responses"""
# Simulates LLM output being cut off mid-line
response = '''{"type": "definition", "entity": "Complete1", "definition": "Full definition"}
{"type": "definition", "entity": "Complete2", "definition": "Another full def"}
{"type": "definition", "entity": "Trunca'''
def test_parse_json_large_responses(self, agent_extractor):
"""Test JSON parsing with very large responses"""
# Create a large JSON structure
large_data = {
"definitions": [
{
"entity": f"Entity {i}",
"definition": f"Definition {i} " + "with more content " * 100
}
for i in range(100)
],
"relationships": [
{
"subject": f"Subject {i}",
"predicate": f"predicate_{i}",
"object": f"Object {i}",
"object-entity": i % 2 == 0
}
for i in range(50)
]
}
large_json_str = json.dumps(large_data)
response = f"```json\n{large_json_str}\n```"
result = agent_extractor.parse_json(response)
assert len(result["definitions"]) == 100
assert len(result["relationships"]) == 50
assert result["definitions"][0]["entity"] == "Entity 0"
result = agent_extractor.parse_jsonl(response)
# Should get 2 valid objects, the truncated line is skipped
assert len(result) == 2
assert result[0]["entity"] == "Complete1"
assert result[1]["entity"] == "Complete2"
def test_parse_jsonl_large_responses(self, agent_extractor):
"""Test JSONL parsing with very large responses"""
# Create a large JSONL response
lines = []
for i in range(100):
lines.append(json.dumps({
"type": "definition",
"entity": f"Entity {i}",
"definition": f"Definition {i} " + "with more content " * 100
}))
for i in range(50):
lines.append(json.dumps({
"type": "relationship",
"subject": f"Subject {i}",
"predicate": f"predicate_{i}",
"object": f"Object {i}",
"object-entity": i % 2 == 0
}))
response = f"```json\n{chr(10).join(lines)}\n```"
result = agent_extractor.parse_jsonl(response)
definitions = [r for r in result if r.get("type") == "definition"]
relationships = [r for r in result if r.get("type") == "relationship"]
assert len(definitions) == 100
assert len(relationships) == 50
assert definitions[0]["entity"] == "Entity 0"
def test_process_extraction_data_empty_metadata(self, agent_extractor):
"""Test processing with empty or minimal metadata"""
# Test with None metadata - may not raise AttributeError depending on implementation
try:
triples, contexts = agent_extractor.process_extraction_data(
{"definitions": [], "relationships": []},
None
)
triples, contexts = agent_extractor.process_extraction_data([], None)
# If it doesn't raise, check the results
assert len(triples) == 0
assert len(contexts) == 0
except (AttributeError, TypeError):
# This is expected behavior when metadata is None
pass
# Test with metadata without ID
metadata = Metadata(id=None, metadata=[])
triples, contexts = agent_extractor.process_extraction_data(
{"definitions": [], "relationships": []},
metadata
)
triples, contexts = agent_extractor.process_extraction_data([], metadata)
assert len(triples) == 0
assert len(contexts) == 0
# Test with metadata with empty string ID
metadata = Metadata(id="", metadata=[])
data = {
"definitions": [{"entity": "Test", "definition": "Test def"}],
"relationships": []
}
data = [{"type": "definition", "entity": "Test", "definition": "Test def"}]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should not create subject-of triples when ID is empty string
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF]
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
assert len(subject_of_triples) == 0
def test_process_extraction_data_special_entity_names(self, agent_extractor):
"""Test processing with special characters in entity names"""
metadata = Metadata(id="doc123", metadata=[])
special_entities = [
"Entity with spaces",
"Entity & Co.",
@ -237,71 +207,62 @@ class TestAgentKgExtractionEdgeCases:
"Quotes: \"test\"",
"Parentheses: (test)",
]
data = {
"definitions": [
{"entity": entity, "definition": f"Definition for {entity}"}
for entity in special_entities
],
"relationships": []
}
data = [
{"type": "definition", "entity": entity, "definition": f"Definition for {entity}"}
for entity in special_entities
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Verify all entities were processed
assert len(contexts) == len(special_entities)
# Verify URIs were properly encoded
for i, entity in enumerate(special_entities):
expected_uri = f"{TRUSTGRAPH_ENTITIES}{urllib.parse.quote(entity)}"
assert contexts[i].entity.value == expected_uri
assert contexts[i].entity.iri == expected_uri
def test_process_extraction_data_very_long_definitions(self, agent_extractor):
"""Test processing with very long entity definitions"""
metadata = Metadata(id="doc123", metadata=[])
# Create very long definition
long_definition = "This is a very long definition. " * 1000
data = {
"definitions": [
{"entity": "Test Entity", "definition": long_definition}
],
"relationships": []
}
data = [
{"type": "definition", "entity": "Test Entity", "definition": long_definition}
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should handle long definitions without issues
assert len(contexts) == 1
assert contexts[0].context == long_definition
# Find definition triple
def_triple = next((t for t in triples if t.p.value == DEFINITION), None)
def_triple = next((t for t in triples if t.p.iri == DEFINITION), None)
assert def_triple is not None
assert def_triple.o.value == long_definition
def test_process_extraction_data_duplicate_entities(self, agent_extractor):
"""Test processing with duplicate entity names"""
metadata = Metadata(id="doc123", metadata=[])
data = {
"definitions": [
{"entity": "Machine Learning", "definition": "First definition"},
{"entity": "Machine Learning", "definition": "Second definition"}, # Duplicate
{"entity": "AI", "definition": "AI definition"},
{"entity": "AI", "definition": "Another AI definition"}, # Duplicate
],
"relationships": []
}
data = [
{"type": "definition", "entity": "Machine Learning", "definition": "First definition"},
{"type": "definition", "entity": "Machine Learning", "definition": "Second definition"}, # Duplicate
{"type": "definition", "entity": "AI", "definition": "AI definition"},
{"type": "definition", "entity": "AI", "definition": "Another AI definition"}, # Duplicate
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should process all entries (including duplicates)
assert len(contexts) == 4
# Check that both definitions for "Machine Learning" are present
ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.value]
ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.iri]
assert len(ml_contexts) == 2
assert ml_contexts[0].context == "First definition"
assert ml_contexts[1].context == "Second definition"
@ -309,49 +270,44 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_empty_strings(self, agent_extractor):
"""Test processing with empty strings in data"""
metadata = Metadata(id="doc123", metadata=[])
data = {
"definitions": [
{"entity": "", "definition": "Definition for empty entity"},
{"entity": "Valid Entity", "definition": ""},
{"entity": " ", "definition": " "}, # Whitespace only
],
"relationships": [
{"subject": "", "predicate": "test", "object": "test", "object-entity": True},
{"subject": "test", "predicate": "", "object": "test", "object-entity": True},
{"subject": "test", "predicate": "test", "object": "", "object-entity": True},
]
}
data = [
{"type": "definition", "entity": "", "definition": "Definition for empty entity"},
{"type": "definition", "entity": "Valid Entity", "definition": ""},
{"type": "definition", "entity": " ", "definition": " "}, # Whitespace only
{"type": "relationship", "subject": "", "predicate": "test", "object": "test", "object-entity": True},
{"type": "relationship", "subject": "test", "predicate": "", "object": "test", "object-entity": True},
{"type": "relationship", "subject": "test", "predicate": "test", "object": "", "object-entity": True},
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should handle empty strings by creating URIs (even if empty)
assert len(contexts) == 3
# Empty entity should create empty URI after encoding
empty_entity_context = next((ec for ec in contexts if ec.entity.value == TRUSTGRAPH_ENTITIES), None)
empty_entity_context = next((ec for ec in contexts if ec.entity.iri == TRUSTGRAPH_ENTITIES), None)
assert empty_entity_context is not None
def test_process_extraction_data_nested_json_in_strings(self, agent_extractor):
"""Test processing when definitions contain JSON-like strings"""
metadata = Metadata(id="doc123", metadata=[])
data = {
"definitions": [
{
"entity": "JSON Entity",
"definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}'
},
{
"entity": "Array Entity",
"definition": 'Contains array: [1, 2, 3, "string"]'
}
],
"relationships": []
}
data = [
{
"type": "definition",
"entity": "JSON Entity",
"definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}'
},
{
"type": "definition",
"entity": "Array Entity",
"definition": 'Contains array: [1, 2, 3, "string"]'
}
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should handle JSON strings in definitions without parsing them
assert len(contexts) == 2
assert '{"key": "value"' in contexts[0].context
@ -360,32 +316,29 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_boolean_object_entity_variations(self, agent_extractor):
"""Test processing with various boolean values for object-entity"""
metadata = Metadata(id="doc123", metadata=[])
data = {
"definitions": [],
"relationships": [
# Explicit True
{"subject": "A", "predicate": "rel1", "object": "B", "object-entity": True},
# Explicit False
{"subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False},
# Missing object-entity (should default to True based on code)
{"subject": "A", "predicate": "rel3", "object": "C"},
# String "true" (should be treated as truthy)
{"subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"},
# String "false" (should be treated as truthy in Python)
{"subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"},
# Number 0 (falsy)
{"subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0},
# Number 1 (truthy)
{"subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
]
}
data = [
# Explicit True
{"type": "relationship", "subject": "A", "predicate": "rel1", "object": "B", "object-entity": True},
# Explicit False
{"type": "relationship", "subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False},
# Missing object-entity (should default to True based on code)
{"type": "relationship", "subject": "A", "predicate": "rel3", "object": "C"},
# String "true" (should be treated as truthy)
{"type": "relationship", "subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"},
# String "false" (should be treated as truthy in Python)
{"type": "relationship", "subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"},
# Number 0 (falsy)
{"type": "relationship", "subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0},
# Number 1 (truthy)
{"type": "relationship", "subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should process all relationships
# Note: The current implementation has some logic issues that these tests document
assert len([t for t in triples if t.p.value != RDF_LABEL and t.p.value != SUBJECT_OF]) >= 7
assert len([t for t in triples if t.p.iri != RDF_LABEL and t.p.iri != SUBJECT_OF]) >= 7
@pytest.mark.asyncio
async def test_emit_empty_collections(self, agent_extractor):
@ -437,41 +390,40 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_performance_large_dataset(self, agent_extractor):
"""Test performance with large extraction datasets"""
metadata = Metadata(id="large-doc", metadata=[])
# Create large dataset
# Create large dataset in JSONL format
num_definitions = 1000
num_relationships = 2000
large_data = {
"definitions": [
{
"entity": f"Entity_{i:04d}",
"definition": f"Definition for entity {i} with some detailed explanation."
}
for i in range(num_definitions)
],
"relationships": [
{
"subject": f"Entity_{i % num_definitions:04d}",
"predicate": f"predicate_{i % 10}",
"object": f"Entity_{(i + 1) % num_definitions:04d}",
"object-entity": True
}
for i in range(num_relationships)
]
}
large_data = [
{
"type": "definition",
"entity": f"Entity_{i:04d}",
"definition": f"Definition for entity {i} with some detailed explanation."
}
for i in range(num_definitions)
] + [
{
"type": "relationship",
"subject": f"Entity_{i % num_definitions:04d}",
"predicate": f"predicate_{i % 10}",
"object": f"Entity_{(i + 1) % num_definitions:04d}",
"object-entity": True
}
for i in range(num_relationships)
]
import time
start_time = time.time()
triples, contexts = agent_extractor.process_extraction_data(large_data, metadata)
end_time = time.time()
processing_time = end_time - start_time
# Should complete within reasonable time (adjust threshold as needed)
assert processing_time < 10.0 # 10 seconds threshold
# Verify results
assert len(contexts) == num_definitions
# Triples include labels, definitions, relationships, and subject-of relations

View file

@ -7,7 +7,7 @@ processing graph structures, and performing graph operations.
import pytest
from unittest.mock import Mock
from .conftest import Triple, Value, Metadata
from .conftest import Triple, Metadata
from collections import defaultdict, deque

View file

@ -76,7 +76,7 @@ def cities_schema():
def validator():
"""Create a mock processor with just the validation method"""
from unittest.mock import MagicMock
from trustgraph.extract.kg.objects.processor import Processor
from trustgraph.extract.kg.rows.processor import Processor
# Create a mock processor
mock_processor = MagicMock()

View file

@ -2,13 +2,13 @@
Unit tests for triple construction logic
Tests the core business logic for constructing RDF triples from extracted
entities and relationships, including URI generation, Value object creation,
entities and relationships, including URI generation, Term object creation,
and triple validation.
"""
import pytest
from unittest.mock import Mock
from .conftest import Triple, Triples, Value, Metadata
from .conftest import Triple, Triples, Term, Metadata, IRI, LITERAL
import re
import hashlib
@ -48,80 +48,82 @@ class TestTripleConstructionLogic:
generated_uri = generate_uri(text, entity_type)
assert generated_uri == expected_uri, f"URI generation failed for '{text}'"
def test_value_object_creation(self):
"""Test creation of Value objects for subjects, predicates, and objects"""
def test_term_object_creation(self):
"""Test creation of Term objects for subjects, predicates, and objects"""
# Arrange
def create_value_object(text, is_uri, value_type=""):
return Value(
value=text,
is_uri=is_uri,
type=value_type
)
def create_term_object(text, is_uri, datatype=""):
if is_uri:
return Term(type=IRI, iri=text)
else:
return Term(type=LITERAL, value=text, datatype=datatype if datatype else None)
test_cases = [
("http://trustgraph.ai/kg/person/john-smith", True, ""),
("John Smith", False, "string"),
("42", False, "integer"),
("http://schema.org/worksFor", True, "")
]
# Act & Assert
for value_text, is_uri, value_type in test_cases:
value_obj = create_value_object(value_text, is_uri, value_type)
assert isinstance(value_obj, Value)
assert value_obj.value == value_text
assert value_obj.is_uri == is_uri
assert value_obj.type == value_type
for value_text, is_uri, datatype in test_cases:
term_obj = create_term_object(value_text, is_uri, datatype)
assert isinstance(term_obj, Term)
if is_uri:
assert term_obj.type == IRI
assert term_obj.iri == value_text
else:
assert term_obj.type == LITERAL
assert term_obj.value == value_text
def test_triple_construction_from_relationship(self):
"""Test constructing Triple objects from relationships"""
# Arrange
relationship = {
"subject": "John Smith",
"predicate": "works_for",
"predicate": "works_for",
"object": "OpenAI",
"subject_type": "PERSON",
"object_type": "ORG"
}
def construct_triple(relationship, uri_base="http://trustgraph.ai/kg"):
# Generate URIs
subject_uri = f"{uri_base}/person/{relationship['subject'].lower().replace(' ', '-')}"
object_uri = f"{uri_base}/org/{relationship['object'].lower().replace(' ', '-')}"
# Map predicate to schema.org URI
predicate_mappings = {
"works_for": "http://schema.org/worksFor",
"located_in": "http://schema.org/location",
"developed": "http://schema.org/creator"
}
predicate_uri = predicate_mappings.get(relationship["predicate"],
predicate_uri = predicate_mappings.get(relationship["predicate"],
f"{uri_base}/predicate/{relationship['predicate']}")
# Create Value objects
subject_value = Value(value=subject_uri, is_uri=True, type="")
predicate_value = Value(value=predicate_uri, is_uri=True, type="")
object_value = Value(value=object_uri, is_uri=True, type="")
# Create Term objects
subject_term = Term(type=IRI, iri=subject_uri)
predicate_term = Term(type=IRI, iri=predicate_uri)
object_term = Term(type=IRI, iri=object_uri)
# Create Triple
return Triple(
s=subject_value,
p=predicate_value,
o=object_value
s=subject_term,
p=predicate_term,
o=object_term
)
# Act
triple = construct_triple(relationship)
# Assert
assert isinstance(triple, Triple)
assert triple.s.value == "http://trustgraph.ai/kg/person/john-smith"
assert triple.s.is_uri is True
assert triple.p.value == "http://schema.org/worksFor"
assert triple.p.is_uri is True
assert triple.o.value == "http://trustgraph.ai/kg/org/openai"
assert triple.o.is_uri is True
assert triple.s.iri == "http://trustgraph.ai/kg/person/john-smith"
assert triple.s.type == IRI
assert triple.p.iri == "http://schema.org/worksFor"
assert triple.p.type == IRI
assert triple.o.iri == "http://trustgraph.ai/kg/org/openai"
assert triple.o.type == IRI
def test_literal_value_handling(self):
"""Test handling of literal values vs URI values"""
@ -132,10 +134,10 @@ class TestTripleConstructionLogic:
("John Smith", "email", "john@example.com", False), # Literal email
("John Smith", "worksFor", "http://trustgraph.ai/kg/org/openai", True) # URI reference
]
def create_triple_with_literal(subject_uri, predicate, object_value, object_is_uri):
subject_val = Value(value=subject_uri, is_uri=True, type="")
subject_term = Term(type=IRI, iri=subject_uri)
# Determine predicate URI
predicate_mappings = {
"name": "http://schema.org/name",
@ -144,32 +146,37 @@ class TestTripleConstructionLogic:
"worksFor": "http://schema.org/worksFor"
}
predicate_uri = predicate_mappings.get(predicate, f"http://trustgraph.ai/kg/predicate/{predicate}")
predicate_val = Value(value=predicate_uri, is_uri=True, type="")
# Create object value with appropriate type
object_type = ""
if not object_is_uri:
predicate_term = Term(type=IRI, iri=predicate_uri)
# Create object term with appropriate type
if object_is_uri:
object_term = Term(type=IRI, iri=object_value)
else:
datatype = None
if predicate == "age":
object_type = "integer"
datatype = "integer"
elif predicate in ["name", "email"]:
object_type = "string"
object_val = Value(value=object_value, is_uri=object_is_uri, type=object_type)
return Triple(s=subject_val, p=predicate_val, o=object_val)
datatype = "string"
object_term = Term(type=LITERAL, value=object_value, datatype=datatype)
return Triple(s=subject_term, p=predicate_term, o=object_term)
# Act & Assert
for subject_uri, predicate, object_value, object_is_uri in test_data:
subject_full_uri = "http://trustgraph.ai/kg/person/john-smith"
triple = create_triple_with_literal(subject_full_uri, predicate, object_value, object_is_uri)
assert triple.o.is_uri == object_is_uri
assert triple.o.value == object_value
if object_is_uri:
assert triple.o.type == IRI
assert triple.o.iri == object_value
else:
assert triple.o.type == LITERAL
assert triple.o.value == object_value
if predicate == "age":
assert triple.o.type == "integer"
assert triple.o.datatype == "integer"
elif predicate in ["name", "email"]:
assert triple.o.type == "string"
assert triple.o.datatype == "string"
def test_namespace_management(self):
"""Test namespace prefix management and expansion"""
@ -216,63 +223,74 @@ class TestTripleConstructionLogic:
def test_triple_validation(self):
"""Test triple validation rules"""
# Arrange
def get_term_value(term):
"""Extract value from a Term"""
if term.type == IRI:
return term.iri
else:
return term.value
def validate_triple(triple):
errors = []
# Check required components
if not triple.s or not triple.s.value:
s_val = get_term_value(triple.s) if triple.s else None
p_val = get_term_value(triple.p) if triple.p else None
o_val = get_term_value(triple.o) if triple.o else None
if not triple.s or not s_val:
errors.append("Missing or empty subject")
if not triple.p or not triple.p.value:
if not triple.p or not p_val:
errors.append("Missing or empty predicate")
if not triple.o or not triple.o.value:
if not triple.o or not o_val:
errors.append("Missing or empty object")
# Check URI validity for URI values
uri_pattern = r'^https?://[^\s/$.?#].[^\s]*$'
if triple.s.is_uri and not re.match(uri_pattern, triple.s.value):
if triple.s.type == IRI and not re.match(uri_pattern, triple.s.iri or ""):
errors.append("Invalid subject URI format")
if triple.p.is_uri and not re.match(uri_pattern, triple.p.value):
if triple.p.type == IRI and not re.match(uri_pattern, triple.p.iri or ""):
errors.append("Invalid predicate URI format")
if triple.o.is_uri and not re.match(uri_pattern, triple.o.value):
if triple.o.type == IRI and not re.match(uri_pattern, triple.o.iri or ""):
errors.append("Invalid object URI format")
# Predicates should typically be URIs
if not triple.p.is_uri:
if triple.p.type != IRI:
errors.append("Predicate should be a URI")
return len(errors) == 0, errors
# Test valid triple
valid_triple = Triple(
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=Value(value="John Smith", is_uri=False, type="string")
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
p=Term(type=IRI, iri="http://schema.org/name"),
o=Term(type=LITERAL, value="John Smith", datatype="string")
)
# Test invalid triples
invalid_triples = [
Triple(s=Value(value="", is_uri=True, type=""),
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=Value(value="John", is_uri=False, type="")), # Empty subject
Triple(s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
p=Value(value="name", is_uri=False, type=""), # Non-URI predicate
o=Value(value="John", is_uri=False, type="")),
Triple(s=Value(value="invalid-uri", is_uri=True, type=""),
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=Value(value="John", is_uri=False, type="")) # Invalid URI format
Triple(s=Term(type=IRI, iri=""),
p=Term(type=IRI, iri="http://schema.org/name"),
o=Term(type=LITERAL, value="John")), # Empty subject
Triple(s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
p=Term(type=LITERAL, value="name"), # Non-URI predicate
o=Term(type=LITERAL, value="John")),
Triple(s=Term(type=IRI, iri="invalid-uri"),
p=Term(type=IRI, iri="http://schema.org/name"),
o=Term(type=LITERAL, value="John")) # Invalid URI format
]
# Act & Assert
is_valid, errors = validate_triple(valid_triple)
assert is_valid, f"Valid triple failed validation: {errors}"
for invalid_triple in invalid_triples:
is_valid, errors = validate_triple(invalid_triple)
assert not is_valid, f"Invalid triple passed validation: {invalid_triple}"
@ -286,97 +304,97 @@ class TestTripleConstructionLogic:
{"text": "OpenAI", "type": "ORG"},
{"text": "San Francisco", "type": "PLACE"}
]
relationships = [
{"subject": "John Smith", "predicate": "works_for", "object": "OpenAI"},
{"subject": "OpenAI", "predicate": "located_in", "object": "San Francisco"}
]
def construct_triple_batch(entities, relationships, document_id="doc-1"):
triples = []
# Create type triples for entities
for entity in entities:
entity_uri = f"http://trustgraph.ai/kg/{entity['type'].lower()}/{entity['text'].lower().replace(' ', '-')}"
type_uri = f"http://trustgraph.ai/kg/type/{entity['type']}"
type_triple = Triple(
s=Value(value=entity_uri, is_uri=True, type=""),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True, type=""),
o=Value(value=type_uri, is_uri=True, type="")
s=Term(type=IRI, iri=entity_uri),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri=type_uri)
)
triples.append(type_triple)
# Create relationship triples
for rel in relationships:
subject_uri = f"http://trustgraph.ai/kg/entity/{rel['subject'].lower().replace(' ', '-')}"
object_uri = f"http://trustgraph.ai/kg/entity/{rel['object'].lower().replace(' ', '-')}"
predicate_uri = f"http://schema.org/{rel['predicate'].replace('_', '')}"
rel_triple = Triple(
s=Value(value=subject_uri, is_uri=True, type=""),
p=Value(value=predicate_uri, is_uri=True, type=""),
o=Value(value=object_uri, is_uri=True, type="")
s=Term(type=IRI, iri=subject_uri),
p=Term(type=IRI, iri=predicate_uri),
o=Term(type=IRI, iri=object_uri)
)
triples.append(rel_triple)
return triples
# Act
triples = construct_triple_batch(entities, relationships)
# Assert
assert len(triples) == len(entities) + len(relationships) # Type triples + relationship triples
# Check that all triples are valid Triple objects
for triple in triples:
assert isinstance(triple, Triple)
assert triple.s.value != ""
assert triple.p.value != ""
assert triple.o.value != ""
assert triple.s.iri != ""
assert triple.p.iri != ""
assert triple.o.iri != ""
def test_triples_batch_object_creation(self):
"""Test creating Triples batch objects with metadata"""
# Arrange
sample_triples = [
Triple(
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=Value(value="John Smith", is_uri=False, type="string")
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
p=Term(type=IRI, iri="http://schema.org/name"),
o=Term(type=LITERAL, value="John Smith", datatype="string")
),
Triple(
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
p=Value(value="http://schema.org/worksFor", is_uri=True, type=""),
o=Value(value="http://trustgraph.ai/kg/org/openai", is_uri=True, type="")
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
p=Term(type=IRI, iri="http://schema.org/worksFor"),
o=Term(type=IRI, iri="http://trustgraph.ai/kg/org/openai")
)
]
metadata = Metadata(
id="test-doc-123",
user="test_user",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act
triples_batch = Triples(
metadata=metadata,
triples=sample_triples
)
# Assert
assert isinstance(triples_batch, Triples)
assert triples_batch.metadata.id == "test-doc-123"
assert triples_batch.metadata.user == "test_user"
assert triples_batch.metadata.collection == "test_collection"
assert len(triples_batch.triples) == 2
# Check that triples are properly embedded
for triple in triples_batch.triples:
assert isinstance(triple, Triple)
assert isinstance(triple.s, Value)
assert isinstance(triple.p, Value)
assert isinstance(triple.o, Value)
assert isinstance(triple.s, Term)
assert isinstance(triple.p, Term)
assert isinstance(triple.o, Term)
def test_uri_collision_handling(self):
"""Test handling of URI collisions and duplicate detection"""

View file

@ -339,7 +339,250 @@ class TestPromptManager:
"""Test PromptManager with minimal configuration"""
pm = PromptManager()
pm.load_config({}) # Empty config
assert pm.config.system_template == "Be helpful." # Default system
assert pm.terms == {} # Default empty terms
assert len(pm.prompts) == 0
assert len(pm.prompts) == 0
@pytest.mark.unit
class TestPromptManagerJsonl:
"""Unit tests for PromptManager JSONL functionality"""
@pytest.fixture
def jsonl_config(self):
"""Configuration with JSONL response type prompts"""
return {
"system": json.dumps("You are an extraction assistant."),
"template-index": json.dumps(["extract_simple", "extract_with_schema", "extract_mixed"]),
"template.extract_simple": json.dumps({
"prompt": "Extract entities from: {{ text }}",
"response-type": "jsonl"
}),
"template.extract_with_schema": json.dumps({
"prompt": "Extract definitions from: {{ text }}",
"response-type": "jsonl",
"schema": {
"type": "object",
"properties": {
"entity": {"type": "string"},
"definition": {"type": "string"}
},
"required": ["entity", "definition"]
}
}),
"template.extract_mixed": json.dumps({
"prompt": "Extract knowledge from: {{ text }}",
"response-type": "jsonl",
"schema": {
"oneOf": [
{
"type": "object",
"properties": {
"type": {"const": "definition"},
"entity": {"type": "string"},
"definition": {"type": "string"}
},
"required": ["type", "entity", "definition"]
},
{
"type": "object",
"properties": {
"type": {"const": "relationship"},
"subject": {"type": "string"},
"predicate": {"type": "string"},
"object": {"type": "string"}
},
"required": ["type", "subject", "predicate", "object"]
}
]
}
})
}
@pytest.fixture
def prompt_manager(self, jsonl_config):
"""Create a PromptManager with JSONL configuration"""
pm = PromptManager()
pm.load_config(jsonl_config)
return pm
def test_parse_jsonl_basic(self, prompt_manager):
"""Test basic JSONL parsing"""
text = '{"entity": "cat", "definition": "A small furry animal"}\n{"entity": "dog", "definition": "A loyal pet"}'
result = prompt_manager.parse_jsonl(text)
assert len(result) == 2
assert result[0]["entity"] == "cat"
assert result[1]["entity"] == "dog"
def test_parse_jsonl_with_empty_lines(self, prompt_manager):
"""Test JSONL parsing skips empty lines"""
text = '{"entity": "cat"}\n\n\n{"entity": "dog"}\n'
result = prompt_manager.parse_jsonl(text)
assert len(result) == 2
def test_parse_jsonl_with_markdown_fences(self, prompt_manager):
"""Test JSONL parsing strips markdown code fences"""
text = '''```json
{"entity": "cat", "definition": "A furry animal"}
{"entity": "dog", "definition": "A loyal pet"}
```'''
result = prompt_manager.parse_jsonl(text)
assert len(result) == 2
assert result[0]["entity"] == "cat"
assert result[1]["entity"] == "dog"
def test_parse_jsonl_with_jsonl_fence(self, prompt_manager):
"""Test JSONL parsing strips jsonl-marked code fences"""
text = '''```jsonl
{"entity": "cat"}
{"entity": "dog"}
```'''
result = prompt_manager.parse_jsonl(text)
assert len(result) == 2
def test_parse_jsonl_truncation_resilience(self, prompt_manager):
"""Test JSONL parsing handles truncated final line"""
text = '{"entity": "cat", "definition": "Complete"}\n{"entity": "dog", "defi'
result = prompt_manager.parse_jsonl(text)
# Should get the first valid object, skip the truncated one
assert len(result) == 1
assert result[0]["entity"] == "cat"
def test_parse_jsonl_invalid_lines_skipped(self, prompt_manager):
"""Test JSONL parsing skips invalid JSON lines"""
text = '''{"entity": "valid1"}
not json at all
{"entity": "valid2"}
{broken json
{"entity": "valid3"}'''
result = prompt_manager.parse_jsonl(text)
assert len(result) == 3
assert result[0]["entity"] == "valid1"
assert result[1]["entity"] == "valid2"
assert result[2]["entity"] == "valid3"
def test_parse_jsonl_empty_input(self, prompt_manager):
"""Test JSONL parsing with empty input"""
result = prompt_manager.parse_jsonl("")
assert result == []
result = prompt_manager.parse_jsonl("\n\n\n")
assert result == []
@pytest.mark.asyncio
async def test_invoke_jsonl_response(self, prompt_manager):
"""Test invoking a prompt with JSONL response"""
mock_llm = AsyncMock()
mock_llm.return_value = '{"entity": "photosynthesis", "definition": "Plant process"}\n{"entity": "mitosis", "definition": "Cell division"}'
result = await prompt_manager.invoke(
"extract_simple",
{"text": "Biology text"},
mock_llm
)
assert isinstance(result, list)
assert len(result) == 2
assert result[0]["entity"] == "photosynthesis"
assert result[1]["entity"] == "mitosis"
@pytest.mark.asyncio
async def test_invoke_jsonl_with_schema_validation(self, prompt_manager):
"""Test JSONL response with schema validation"""
mock_llm = AsyncMock()
mock_llm.return_value = '{"entity": "cat", "definition": "A pet"}\n{"entity": "dog", "definition": "Another pet"}'
result = await prompt_manager.invoke(
"extract_with_schema",
{"text": "Animal text"},
mock_llm
)
assert len(result) == 2
assert all("entity" in obj and "definition" in obj for obj in result)
@pytest.mark.asyncio
async def test_invoke_jsonl_schema_filters_invalid(self, prompt_manager):
"""Test JSONL schema validation filters out invalid objects"""
mock_llm = AsyncMock()
# Second object is missing required 'definition' field
mock_llm.return_value = '{"entity": "valid", "definition": "Has both fields"}\n{"entity": "invalid_missing_definition"}\n{"entity": "also_valid", "definition": "Complete"}'
result = await prompt_manager.invoke(
"extract_with_schema",
{"text": "Test text"},
mock_llm
)
# Only the two valid objects should be returned
assert len(result) == 2
assert result[0]["entity"] == "valid"
assert result[1]["entity"] == "also_valid"
@pytest.mark.asyncio
async def test_invoke_jsonl_mixed_types(self, prompt_manager):
"""Test JSONL with discriminated union schema (oneOf)"""
mock_llm = AsyncMock()
mock_llm.return_value = '''{"type": "definition", "entity": "DNA", "definition": "Genetic material"}
{"type": "relationship", "subject": "DNA", "predicate": "found_in", "object": "nucleus"}
{"type": "definition", "entity": "RNA", "definition": "Messenger molecule"}'''
result = await prompt_manager.invoke(
"extract_mixed",
{"text": "Biology text"},
mock_llm
)
assert len(result) == 3
# Check definitions
definitions = [r for r in result if r.get("type") == "definition"]
assert len(definitions) == 2
# Check relationships
relationships = [r for r in result if r.get("type") == "relationship"]
assert len(relationships) == 1
assert relationships[0]["subject"] == "DNA"
@pytest.mark.asyncio
async def test_invoke_jsonl_empty_result(self, prompt_manager):
"""Test JSONL response that yields no valid objects"""
mock_llm = AsyncMock()
mock_llm.return_value = "No JSON here at all"
result = await prompt_manager.invoke(
"extract_simple",
{"text": "Test"},
mock_llm
)
assert result == []
@pytest.mark.asyncio
async def test_invoke_jsonl_without_schema(self, prompt_manager):
"""Test JSONL response without schema validation"""
mock_llm = AsyncMock()
mock_llm.return_value = '{"any": "structure"}\n{"completely": "different"}'
result = await prompt_manager.invoke(
"extract_simple",
{"text": "Test"},
mock_llm
)
assert len(result) == 2
assert result[0] == {"any": "structure"}
assert result[1] == {"completely": "different"}

View file

@ -167,7 +167,7 @@ class TestFlowClient:
expected_methods = [
'text_completion', 'agent', 'graph_rag', 'document_rag',
'graph_embeddings_query', 'embeddings', 'prompt',
'triples_query', 'objects_query'
'triples_query', 'rows_query'
]
for method in expected_methods:
@ -216,7 +216,7 @@ class TestSocketClient:
expected_methods = [
'agent', 'text_completion', 'graph_rag', 'document_rag',
'prompt', 'graph_embeddings_query', 'embeddings',
'triples_query', 'objects_query', 'mcp_tool'
'triples_query', 'rows_query', 'mcp_tool'
]
for method in expected_methods:
@ -243,7 +243,7 @@ class TestBulkClient:
'import_graph_embeddings',
'import_document_embeddings',
'import_entity_contexts',
'import_objects'
'import_rows'
]
for method in import_methods:

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.milvus.service import Processor
from trustgraph.schema import Value, GraphEmbeddingsRequest
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
class TestMilvusGraphEmbeddingsQueryProcessor:
@ -68,50 +68,50 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_single_vector(self, processor):
@ -138,17 +138,17 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
)
# Verify results are converted to Value objects
# Verify results are converted to Term objects
assert len(result) == 3
assert isinstance(result[0], Value)
assert result[0].value == "http://example.com/entity1"
assert result[0].is_uri is True
assert isinstance(result[1], Value)
assert result[1].value == "http://example.com/entity2"
assert result[1].is_uri is True
assert isinstance(result[2], Value)
assert isinstance(result[0], Term)
assert result[0].iri == "http://example.com/entity1"
assert result[0].type == IRI
assert isinstance(result[1], Term)
assert result[1].iri == "http://example.com/entity2"
assert result[1].type == IRI
assert isinstance(result[2], Term)
assert result[2].value == "literal entity"
assert result[2].is_uri is False
assert result[2].type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor):
@ -186,7 +186,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify results are deduplicated and limited
assert len(result) == 3
entity_values = [r.value for r in result]
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@ -246,7 +246,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify duplicates are removed
assert len(result) == 3
entity_values = [r.value for r in result]
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert len(set(entity_values)) == 3 # All unique
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
@ -346,14 +346,14 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 4
# Check URI entities
uri_results = [r for r in result if r.is_uri]
uri_results = [r for r in result if r.type == IRI]
assert len(uri_results) == 2
uri_values = [r.value for r in uri_results]
uri_values = [r.iri for r in uri_results]
assert "http://example.com/uri_entity" in uri_values
assert "https://example.com/another_uri" in uri_values
# Check literal entities
literal_results = [r for r in result if not r.is_uri]
literal_results = [r for r in result if not r.type == IRI]
assert len(literal_results) == 2
literal_values = [r.value for r in literal_results]
assert "literal entity text" in literal_values
@ -486,7 +486,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify results from all dimensions
assert len(result) == 3
entity_values = [r.value for r in result]
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert "entity_2d" in entity_values
assert "entity_4d" in entity_values
assert "entity_3d" in entity_values

View file

@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
from trustgraph.query.graph_embeddings.pinecone.service import Processor
from trustgraph.schema import Value
from trustgraph.schema import Term, IRI, LITERAL
class TestPineconeGraphEmbeddingsQueryProcessor:
@ -105,27 +105,27 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
uri_entity = "http://example.org/entity"
value = processor.create_value(uri_entity)
assert isinstance(value, Value)
assert isinstance(value, Term)
assert value.value == uri_entity
assert value.is_uri == True
assert value.type == IRI
def test_create_value_https_uri(self, processor):
"""Test create_value method for HTTPS URI entities"""
uri_entity = "https://example.org/entity"
value = processor.create_value(uri_entity)
assert isinstance(value, Value)
assert isinstance(value, Term)
assert value.value == uri_entity
assert value.is_uri == True
assert value.type == IRI
def test_create_value_literal(self, processor):
"""Test create_value method for literal entities"""
literal_entity = "literal_entity"
value = processor.create_value(literal_entity)
assert isinstance(value, Value)
assert isinstance(value, Term)
assert value.value == literal_entity
assert value.is_uri == False
assert value.type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_single_vector(self, processor):
@ -165,11 +165,11 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
# Verify results
assert len(entities) == 3
assert entities[0].value == 'http://example.org/entity1'
assert entities[0].is_uri == True
assert entities[0].type == IRI
assert entities[1].value == 'entity2'
assert entities[1].is_uri == False
assert entities[1].type == LITERAL
assert entities[2].value == 'http://example.org/entity3'
assert entities[2].is_uri == True
assert entities[2].type == IRI
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):

View file

@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.graph_embeddings.qdrant.service import Processor
from trustgraph.schema import IRI, LITERAL
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
@ -85,10 +86,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
value = processor.create_value('http://example.com/entity')
# Assert
assert hasattr(value, 'value')
assert value.value == 'http://example.com/entity'
assert hasattr(value, 'is_uri')
assert value.is_uri == True
assert hasattr(value, 'iri')
assert value.iri == 'http://example.com/entity'
assert hasattr(value, 'type')
assert value.type == IRI
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -109,10 +110,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
value = processor.create_value('https://secure.example.com/entity')
# Assert
assert hasattr(value, 'value')
assert value.value == 'https://secure.example.com/entity'
assert hasattr(value, 'is_uri')
assert value.is_uri == True
assert hasattr(value, 'iri')
assert value.iri == 'https://secure.example.com/entity'
assert hasattr(value, 'type')
assert value.type == IRI
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -135,8 +136,8 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
assert hasattr(value, 'value')
assert value.value == 'regular entity name'
assert hasattr(value, 'is_uri')
assert value.is_uri == False
assert hasattr(value, 'type')
assert value.type == LITERAL
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -428,14 +429,14 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
assert len(result) == 3
# Check URI entities
uri_entities = [entity for entity in result if hasattr(entity, 'is_uri') and entity.is_uri]
uri_entities = [entity for entity in result if entity.type == IRI]
assert len(uri_entities) == 2
uri_values = [entity.value for entity in uri_entities]
uri_values = [entity.iri for entity in uri_entities]
assert 'http://example.com/entity1' in uri_values
assert 'https://secure.example.com/entity2' in uri_values
# Check regular entities
regular_entities = [entity for entity in result if hasattr(entity, 'is_uri') and not entity.is_uri]
regular_entities = [entity for entity in result if entity.type == LITERAL]
assert len(regular_entities) == 1
assert regular_entities[0].value == 'regular entity'

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestMemgraphQueryUserCollectionIsolation:
@ -24,9 +24,9 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="test_object", is_uri=False),
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="test_object"),
limit=1000
)
@ -65,8 +65,8 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
o=None,
limit=1000
)
@ -105,9 +105,9 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=Value(value="http://example.com/o", is_uri=True),
o=Term(type=IRI, iri="http://example.com/o"),
limit=1000
)
@ -145,7 +145,7 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None,
limit=1000
@ -185,8 +185,8 @@ class TestMemgraphQueryUserCollectionIsolation:
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="literal", is_uri=False),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="literal"),
limit=1000
)
@ -225,7 +225,7 @@ class TestMemgraphQueryUserCollectionIsolation:
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
p=Term(type=IRI, iri="http://example.com/p"),
o=None,
limit=1000
)
@ -265,7 +265,7 @@ class TestMemgraphQueryUserCollectionIsolation:
collection="test_collection",
s=None,
p=None,
o=Value(value="test_value", is_uri=False),
o=Term(type=LITERAL, value="test_value"),
limit=1000
)
@ -355,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
# Query without user/collection fields
query = TriplesQueryRequest(
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None,
limit=1000
@ -385,7 +385,7 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None,
limit=1000
@ -416,17 +416,17 @@ class TestMemgraphQueryUserCollectionIsolation:
assert len(result) == 2
# First triple (literal object)
assert result[0].s.value == "http://example.com/s"
assert result[0].s.is_uri == True
assert result[0].p.value == "http://example.com/p1"
assert result[0].p.is_uri == True
assert result[0].s.iri == "http://example.com/s"
assert result[0].s.type == IRI
assert result[0].p.iri == "http://example.com/p1"
assert result[0].p.type == IRI
assert result[0].o.value == "literal_value"
assert result[0].o.is_uri == False
assert result[0].o.type == LITERAL
# Second triple (URI object)
assert result[1].s.value == "http://example.com/s"
assert result[1].s.is_uri == True
assert result[1].p.value == "http://example.com/p2"
assert result[1].p.is_uri == True
assert result[1].o.value == "http://example.com/o"
assert result[1].o.is_uri == True
assert result[1].s.iri == "http://example.com/s"
assert result[1].s.type == IRI
assert result[1].p.iri == "http://example.com/p2"
assert result[1].p.type == IRI
assert result[1].o.iri == "http://example.com/o"
assert result[1].o.type == IRI

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestNeo4jQueryUserCollectionIsolation:
@ -24,21 +24,23 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="test_object", is_uri=False)
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="test_object"),
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SPO query for literal includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN $src as src"
"RETURN $src as src "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -63,23 +65,25 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=None
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SP query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN dest.value as dest"
"RETURN dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
src="http://example.com/s",
@ -88,13 +92,14 @@ class TestNeo4jQueryUserCollectionIsolation:
collection="test_collection",
database_='neo4j'
)
# Verify SP query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest"
"RETURN dest.uri as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -118,21 +123,23 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=Value(value="http://example.com/o", is_uri=True)
o=Term(type=IRI, iri="http://example.com/o"),
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SO query for nodes includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN rel.uri as rel"
"RETURN rel.uri as rel "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -156,23 +163,25 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify S query includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest"
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
@ -194,20 +203,22 @@ class TestNeo4jQueryUserCollectionIsolation:
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="literal", is_uri=False)
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="literal"),
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify PO query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src"
"RETURN src.uri as src "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -232,20 +243,22 @@ class TestNeo4jQueryUserCollectionIsolation:
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=None
p=Term(type=IRI, iri="http://example.com/p"),
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify P query includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest"
"RETURN src.uri as src, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -270,19 +283,21 @@ class TestNeo4jQueryUserCollectionIsolation:
collection="test_collection",
s=None,
p=None,
o=Value(value="test_value", is_uri=False)
o=Term(type=LITERAL, value="test_value"),
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify O query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel"
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -307,34 +322,37 @@ class TestNeo4jQueryUserCollectionIsolation:
collection="test_collection",
s=None,
p=None,
o=None
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify wildcard query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest"
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -355,9 +373,10 @@ class TestNeo4jQueryUserCollectionIsolation:
# Query without user/collection fields
query = TriplesQueryRequest(
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
@ -384,47 +403,48 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None
o=None,
limit=10
)
# Mock some results
mock_record1 = MagicMock()
mock_record1.data.return_value = {
"rel": "http://example.com/p1",
"dest": "literal_value"
}
mock_record2 = MagicMock()
mock_record2.data.return_value = {
"rel": "http://example.com/p2",
"dest": "http://example.com/o"
}
# Return results for literal query, empty for node query
mock_driver.execute_query.side_effect = [
([mock_record1], MagicMock(), MagicMock()), # Literal query
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
# Verify results are proper Triple objects
assert len(result) == 2
# First triple (literal object)
assert result[0].s.value == "http://example.com/s"
assert result[0].s.is_uri == True
assert result[0].p.value == "http://example.com/p1"
assert result[0].p.is_uri == True
assert result[0].s.iri == "http://example.com/s"
assert result[0].s.type == IRI
assert result[0].p.iri == "http://example.com/p1"
assert result[0].p.type == IRI
assert result[0].o.value == "literal_value"
assert result[0].o.is_uri == False
assert result[0].o.type == LITERAL
# Second triple (URI object)
assert result[1].s.value == "http://example.com/s"
assert result[1].s.is_uri == True
assert result[1].p.value == "http://example.com/p2"
assert result[1].p.is_uri == True
assert result[1].o.value == "http://example.com/o"
assert result[1].o.is_uri == True
assert result[1].s.iri == "http://example.com/s"
assert result[1].s.type == IRI
assert result[1].p.iri == "http://example.com/p2"
assert result[1].p.type == IRI
assert result[1].o.iri == "http://example.com/o"
assert result[1].o.type == IRI

View file

@ -1,10 +1,11 @@
"""
Unit tests for Cassandra Objects GraphQL Query Processor
Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation)
Tests the business logic of the GraphQL query processor including:
- GraphQL schema generation from RowSchema
- Query execution and validation
- CQL translation logic
- Schema configuration handling
- Query execution using unified rows table
- Name sanitization
- GraphQL query execution
- Message processing logic
"""
@ -12,119 +13,91 @@ import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
import strawberry
from strawberry import Schema
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.schema import RowSchema, Field
class TestObjectsGraphQLQueryLogic:
"""Test business logic without external dependencies"""
def test_get_python_type_mapping(self):
    """Test schema field type conversion to Python types."""
    proc = MagicMock()
    # Bind the real implementation onto a bare mock so no other
    # Processor machinery (Cassandra, config) is touched.
    proc.get_python_type = Processor.get_python_type.__get__(proc, Processor)
    # Each schema type name maps to its expected Python type; the
    # string-like types (timestamp/date/time/uuid) all map to str.
    expected = {
        "string": str,
        "integer": int,
        "float": float,
        "boolean": bool,
        "timestamp": str,
        "date": str,
        "time": str,
        "uuid": str,
    }
    for type_name, py_type in expected.items():
        assert proc.get_python_type(type_name) == py_type
    # Unrecognised type names fall back to str.
    assert proc.get_python_type("unknown_type") == str
def test_create_graphql_type_basic_fields(self):
    """Test GraphQL type creation for basic field types."""
    proc = MagicMock()
    # Bind only the methods under test onto a bare mock.
    proc.get_python_type = Processor.get_python_type.__get__(proc, Processor)
    proc.create_graphql_type = Processor.create_graphql_type.__get__(proc, Processor)
    # Schema mixing required/optional fields across several types.
    test_fields = [
        Field(name="id", type="string", primary=True, required=True,
              description="Primary key"),
        Field(name="name", type="string", required=True,
              description="Name field"),
        Field(name="age", type="integer", required=False,
              description="Optional age"),
        Field(name="active", type="boolean", required=False,
              description="Status flag"),
    ]
    schema = RowSchema(
        name="test_table",
        description="Test table",
        fields=test_fields,
    )
    graphql_type = proc.create_graphql_type("test_table", schema)
    # The generated type exists and its name derives from the table name.
    assert graphql_type is not None
    assert hasattr(graphql_type, '__name__')
    type_name = graphql_type.__name__
    assert "TestTable" in type_name or "test_table" in type_name.lower()
class TestRowsGraphQLQueryLogic:
"""Test business logic for unified table query implementation"""
def test_sanitize_name_cassandra_compatibility(self):
    """Test name sanitization for Cassandra field names.

    Sanitized names must be valid Cassandra identifiers: lowercase,
    alphanumeric/underscore only, not starting with a digit.
    """
    processor = MagicMock()
    processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
    # Test field name sanitization (uses r_ prefix like storage processor)
    assert processor.sanitize_name("simple_field") == "simple_field"
    assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes"
    assert processor.sanitize_name("field.with.dots") == "field_with_dots"
    # Names starting with a digit get the r_ prefix so they remain
    # valid Cassandra identifiers.
    assert processor.sanitize_name("123_field") == "r_123_field"
    assert processor.sanitize_name("field with spaces") == "field_with_spaces"
    assert processor.sanitize_name("special!@#chars") == "special___chars"
    assert processor.sanitize_name("UPPERCASE") == "uppercase"
    assert processor.sanitize_name("CamelCase") == "camelcase"
def test_get_index_names(self):
    """Test extraction of index names from schema.

    Primary-key fields and explicitly indexed fields are index
    candidates; plain fields are not.
    """
    processor = MagicMock()
    processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
    schema = RowSchema(
        name="test_schema",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="category", type="string", indexed=True),
            Field(name="name", type="string"),  # Not indexed
            Field(name="status", type="string", indexed=True)
        ]
    )
    index_names = processor.get_index_names(schema)
    assert "id" in index_names
    assert "category" in index_names
    assert "status" in index_names
    assert "name" not in index_names
    assert len(index_names) == 3
def test_find_matching_index_exact_match(self):
    """Test finding matching index for exact match query."""
    proc = MagicMock()
    # find_matching_index relies on get_index_names, so bind both.
    proc.get_index_names = Processor.get_index_names.__get__(proc, Processor)
    proc.find_matching_index = Processor.find_matching_index.__get__(proc, Processor)
    schema = RowSchema(
        name="test_schema",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="category", type="string", indexed=True),
            Field(name="name", type="string"),  # Not indexed
        ],
    )
    # A filter on an indexed field yields an (index_name, values) pair.
    match = proc.find_matching_index(schema, {"category": "electronics"})
    assert match is not None
    index_name, index_values = match
    assert index_name == "category"
    assert index_values == ["electronics"]
    # A filter on a non-indexed field yields no usable index.
    assert proc.find_matching_index(schema, {"name": "test"}) is None
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configuration"""
processor = MagicMock()
processor.schemas = {}
processor.graphql_types = {}
processor.graphql_schema = None
processor.config_key = "schema" # Set the config key
processor.generate_graphql_schema = AsyncMock()
processor.config_key = "schema"
processor.schema_builder = MagicMock()
processor.schema_builder.clear = MagicMock()
processor.schema_builder.add_schema = MagicMock()
processor.schema_builder.build = MagicMock(return_value=MagicMock())
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
schema_config = {
"schema": {
@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic:
})
}
}
# Process config
await processor.on_schema_config(schema_config, version=1)
# Verify schema was loaded
assert "customer" in processor.schemas
schema = processor.schemas["customer"]
assert schema.name == "customer"
assert len(schema.fields) == 3
# Verify fields
id_field = next(f for f in schema.fields if f.name == "id")
assert id_field.primary is True
# The field should have been created correctly from JSON
# Let's test what we can verify - that the field has the right attributes
assert hasattr(id_field, 'required') # Has the required attribute
assert hasattr(id_field, 'primary') # Has the primary attribute
email_field = next(f for f in schema.fields if f.name == "email")
assert email_field.indexed is True
status_field = next(f for f in schema.fields if f.name == "status")
assert status_field.enum_values == ["active", "inactive"]
# Verify GraphQL schema regeneration was called
processor.generate_graphql_schema.assert_called_once()
def test_cql_query_building_basic(self):
"""Test basic CQL query construction"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to capture the query
mock_result = []
processor.session.execute.return_value = mock_result
# Create test schema
schema = RowSchema(
name="test_table",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string", indexed=True),
Field(name="status", type="string")
]
)
# Test query building
asyncio = pytest.importorskip("asyncio")
async def run_test():
await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="test_table",
row_schema=schema,
filters={"name": "John", "invalid_filter": "ignored"},
limit=10
)
# Run the async test
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(run_test())
finally:
loop.close()
# Verify Cassandra connection and query execution
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
# Verify the query structure (can't easily test exact query without complex mocking)
call_args = processor.session.execute.call_args
query = call_args[0][0] # First positional argument is the query
params = call_args[0][1] # Second positional argument is parameters
# Basic query structure checks
assert "SELECT * FROM test_user.o_test_table" in query
assert "WHERE" in query
assert "collection = %s" in query
assert "LIMIT 10" in query
# Parameters should include collection and name filter
assert "test_collection" in params
assert "John" in params
# Verify schema builder was called
processor.schema_builder.add_schema.assert_called_once()
processor.schema_builder.build.assert_called_once()
@pytest.mark.asyncio
async def test_graphql_context_handling(self):
@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Mock schema execution
mock_result = MagicMock()
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
mock_result.errors = None
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { id name } }',
variables={},
@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic:
user="test_user",
collection="test_collection"
)
# Verify schema.execute was called with correct context
processor.graphql_schema.execute.assert_called_once()
call_args = processor.graphql_schema.execute.call_args
# Verify context was passed
context = call_args[1]['context_value'] # keyword argument
context = call_args[1]['context_value']
assert context["processor"] == processor
assert context["user"] == "test_user"
assert context["collection"] == "test_collection"
# Verify result structure
assert "data" in result
assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]}
@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Create a simple object to simulate GraphQL error instead of MagicMock
# Create a simple object to simulate GraphQL error
class MockError:
def __init__(self, message, path, extensions):
self.message = message
self.path = path
self.extensions = extensions
def __str__(self):
return self.message
mock_error = MockError(
message="Field 'invalid_field' doesn't exist",
path=["customers", "0", "invalid_field"],
extensions={"code": "FIELD_NOT_FOUND"}
)
mock_result = MagicMock()
mock_result.data = None
mock_result.errors = [mock_error]
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { invalid_field } }',
variables={},
operation_name=None,
user="test_user",
user="test_user",
collection="test_collection"
)
# Verify error handling
assert "errors" in result
assert len(result["errors"]) == 1
error = result["errors"][0]
assert error["message"] == "Field 'invalid_field' doesn't exist"
assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path
assert error["path"] == ["customers", "0", "invalid_field"]
assert error["extensions"] == {"code": "FIELD_NOT_FOUND"}
def test_schema_generation_basic_structure(self):
    """Test basic GraphQL schema generation structure."""
    proc = MagicMock()
    customer_schema = RowSchema(
        name="customer",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="name", type="string"),
        ],
    )
    proc.schemas = {"customer": customer_schema}
    proc.graphql_types = {}
    proc.get_python_type = Processor.get_python_type.__get__(proc, Processor)
    proc.create_graphql_type = Processor.create_graphql_type.__get__(proc, Processor)
    # Exercise individual type creation only (the full schema
    # generation path has annotation issues under test).
    proc.graphql_types["customer"] = proc.create_graphql_type(
        "customer", proc.schemas["customer"]
    )
    # Exactly one type was registered and it is non-empty.
    assert len(proc.graphql_types) == 1
    assert "customer" in proc.graphql_types
    assert proc.graphql_types["customer"] is not None
@pytest.mark.asyncio
async def test_message_processing_success(self):
"""Test successful message processing flow"""
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock successful query result
processor.execute_graphql_query.return_value = {
"data": {"customers": [{"id": "1", "name": "John"}]},
"errors": [],
"extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String())
"extensions": {}
}
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
collection="test_collection",
query='{ customers { id name } }',
variables={},
operation_name=None
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-123"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify query was executed
processor.execute_graphql_query.assert_called_once_with(
query='{ customers { id name } }',
@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic:
user="test_user",
collection="test_collection"
)
# Verify response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert isinstance(response_call, RowsQueryResponse)
assert response_call.error is None
assert '"customers"' in response_call.data # JSON encoded
assert len(response_call.errors) == 0
@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock query execution error
processor.execute_graphql_query.side_effect = RuntimeError("No schema available")
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ invalid_query }',
@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic:
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-456"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify error response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify error response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert isinstance(response_call, RowsQueryResponse)
assert response_call.error is not None
assert response_call.error.type == "objects-query-error"
assert response_call.error.type == "rows-query-error"
assert "No schema available" in response_call.error.message
assert response_call.data is None
class TestCQLQueryGeneration:
"""Test CQL query generation logic in isolation"""
def test_partition_key_inclusion(self):
"""Test that collection is always included in queries"""
class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
async def test_query_with_index_match(self):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
# Mock the query building (simplified version)
keyspace = processor.sanitize_name("test_user")
table = processor.sanitize_table("test_table")
query = f"SELECT * FROM {keyspace}.{table}"
where_clauses = ["collection = %s"]
assert "collection = %s" in where_clauses
assert keyspace == "test_user"
assert table == "o_test_table"
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
processor.session.execute.return_value = [mock_row]
schema = RowSchema(
name="products",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string")
]
)
# Query with filter on indexed field
results = await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="products",
row_schema=schema,
filters={"category": "electronics"},
limit=10
)
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
# Verify query structure - should query unified rows table
call_args = processor.session.execute.call_args
query = call_args[0][0]
params = call_args[0][1]
assert "SELECT data, source FROM test_user.rows" in query
assert "collection = %s" in query
assert "schema_name = %s" in query
assert "index_name = %s" in query
assert "index_value = %s" in query
assert params[0] == "test_collection"
assert params[1] == "products"
assert params[2] == "category"
assert params[3] == ["electronics"]
# Verify results
assert len(results) == 1
assert results[0]["id"] == "123"
assert results[0]["category"] == "electronics"
@pytest.mark.asyncio
async def test_query_without_index_match(self):
    """Test query execution without matching index (scan mode)."""
    proc = MagicMock()
    proc.session = MagicMock()
    proc.connect_cassandra = MagicMock()
    # Bind the real query-path methods onto the bare mock.
    for method_name in (
        "sanitize_name",
        "get_index_names",
        "find_matching_index",
        "_matches_filters",
        "query_cassandra",
    ):
        bound = getattr(Processor, method_name).__get__(proc, Processor)
        setattr(proc, method_name, bound)
    # Two stored rows; only one should survive post-filtering.
    row_a = MagicMock()
    row_a.data = {"id": "1", "name": "Product A", "price": "100"}
    row_b = MagicMock()
    row_b.data = {"id": "2", "name": "Product B", "price": "200"}
    proc.session.execute.return_value = [row_a, row_b]
    schema = RowSchema(
        name="products",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="name", type="string"),   # Not indexed
            Field(name="price", type="string"),  # Not indexed
        ],
    )
    # Filtering on a non-indexed field forces a table scan.
    results = await proc.query_cassandra(
        user="test_user",
        collection="test_collection",
        schema_name="products",
        row_schema=schema,
        filters={"name": "Product A"},
        limit=10,
    )
    # Scan mode must use ALLOW FILTERING in the CQL ...
    executed_query = proc.session.execute.call_args[0][0]
    assert "ALLOW FILTERING" in executed_query
    # ... and apply the filter in-process on the returned rows.
    assert len(results) == 1
    assert results[0]["name"] == "Product A"
class TestFilterMatching:
    """Test filter matching logic."""

    @staticmethod
    def _bound_matcher():
        # Bind the real _matches_filters implementation onto a bare
        # mock so no other Processor machinery is involved.
        proc = MagicMock()
        proc._matches_filters = Processor._matches_filters.__get__(proc, Processor)
        return proc

    def test_matches_filters_exact_match(self):
        """Test exact match filter."""
        proc = self._bound_matcher()
        schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
        row = {"status": "active", "name": "test"}
        assert proc._matches_filters(row, {"status": "active"}, schema) is True
        assert proc._matches_filters(row, {"status": "inactive"}, schema) is False

    def test_matches_filters_comparison_operators(self):
        """Test comparison operators in filters."""
        proc = self._bound_matcher()
        schema = RowSchema(name="test", fields=[Field(name="price", type="float")])
        row = {"price": "100.0"}
        # One (filters, expected) pair per comparison operator suffix.
        cases = [
            ({"price_gt": 50}, True),
            ({"price_gt": 150}, False),
            ({"price_lt": 150}, True),
            ({"price_lt": 50}, False),
            ({"price_gte": 100}, True),
            ({"price_gte": 101}, False),
            ({"price_lte": 100}, True),
            ({"price_lte": 99}, False),
        ]
        for filters, expected in cases:
            assert proc._matches_filters(row, filters, schema) is expected

    def test_matches_filters_contains(self):
        """Test contains filter."""
        proc = self._bound_matcher()
        schema = RowSchema(name="test", fields=[Field(name="description", type="string")])
        row = {"description": "A great product for everyone"}
        assert proc._matches_filters(row, {"description_contains": "great"}, schema) is True
        assert proc._matches_filters(row, {"description_contains": "terrible"}, schema) is False

    def test_matches_filters_in_list(self):
        """Test in-list filter."""
        proc = self._bound_matcher()
        schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
        row = {"status": "active"}
        assert proc._matches_filters(row, {"status_in": ["active", "pending"]}, schema) is True
        assert proc._matches_filters(row, {"status_in": ["inactive", "deleted"]}, schema) is False
class TestIndexedFieldFiltering:
    """Test that only indexed or primary key fields can be directly filtered."""

    def test_indexed_field_filtering(self):
        """Test that only indexed or primary key fields can be filtered."""
        # Create schema with mixed field types
        schema = RowSchema(
            name="test",
            fields=[
                Field(name="id", type="string", primary=True),
                Field(name="indexed_field", type="string", indexed=True),
                Field(name="normal_field", type="string", indexed=False),
                Field(name="another_field", type="string")
            ]
        )
        filters = {
            "id": "test123",                 # Primary key - should be included
            "indexed_field": "value",        # Indexed - should be included
            "normal_field": "ignored",       # Not indexed - should be ignored
            "another_field": "also_ignored"  # Not indexed - should be ignored
        }
        # Simulate the filtering logic from the processor: a filter is
        # usable only when the field is a primary key or carries a
        # secondary index.
        valid_filters = []
        for field_name, value in filters.items():
            schema_field = next((f for f in schema.fields if f.name == field_name), None)
            if schema_field and (schema_field.indexed or schema_field.primary):
                valid_filters.append((field_name, value))
        # Only id and indexed_field should be included
        assert len(valid_filters) == 2
        field_names = [f[0] for f in valid_filters]
        assert "id" in field_names
        assert "indexed_field" in field_names
        assert "normal_field" not in field_names
        assert "another_field" not in field_names
class TestGraphQLSchemaGeneration:
    """Test GraphQL schema generation in detail."""

    def test_field_type_annotations(self):
        """Test that GraphQL types have correct field annotations."""
        proc = MagicMock()
        proc.get_python_type = Processor.get_python_type.__get__(proc, Processor)
        proc.create_graphql_type = Processor.create_graphql_type.__get__(proc, Processor)
        # Schema covering required and optional fields of several types.
        field_specs = [
            ("id", "string", True, True),
            ("count", "integer", True, False),
            ("price", "float", False, False),
            ("active", "boolean", False, False),
            ("optional_text", "string", False, False),
        ]
        schema = RowSchema(
            name="test",
            fields=[
                Field(name=n, type=t, required=req, primary=prim)
                for n, t, req, prim in field_specs
            ],
        )
        # Type creation should succeed for this mix of annotations.
        assert proc.create_graphql_type("test", schema) is not None

    def test_basic_type_creation(self):
        """Test that GraphQL types are created correctly."""
        proc = MagicMock()
        proc.schemas = {
            "customer": RowSchema(
                name="customer",
                fields=[Field(name="id", type="string", primary=True)],
            )
        }
        proc.graphql_types = {}
        proc.get_python_type = Processor.get_python_type.__get__(proc, Processor)
        proc.create_graphql_type = Processor.create_graphql_type.__get__(proc, Processor)
        # Create the GraphQL type directly and register it.
        proc.graphql_types["customer"] = proc.create_graphql_type(
            "customer", proc.schemas["customer"]
        )
        # Verify customer type was created and registered.
        assert "customer" in proc.graphql_types
        assert proc.graphql_types["customer"] is not None

View file

@ -5,8 +5,8 @@ Tests for Cassandra triples query service
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.cassandra.service import Processor
from trustgraph.schema import Value
from trustgraph.query.triples.cassandra.service import Processor, create_term
from trustgraph.schema import Term, IRI, LITERAL
class TestCassandraQueryProcessor:
@ -21,94 +21,101 @@ class TestCassandraQueryProcessor:
graph_host='localhost'
)
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
def test_create_term_with_http_uri(self, processor):
"""Test create_term with HTTP URI"""
result = create_term("http://example.com/resource")
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
def test_create_term_with_https_uri(self, processor):
"""Test create_term with HTTPS URI"""
result = create_term("https://example.com/resource")
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_term_with_literal(self, processor):
"""Test create_term with literal value"""
result = create_term("just a literal string")
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
def test_create_term_with_empty_string(self, processor):
"""Test create_term with empty string"""
result = create_term("")
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
def test_create_term_with_partial_uri(self, processor):
"""Test create_term with string that looks like URI but isn't complete"""
result = create_term("http")
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
def test_create_term_with_ftp_uri(self, processor):
"""Test create_term with FTP URI (should not be detected as URI)"""
result = create_term("ftp://example.com/file")
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_spo_query(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_spo_query(self, mock_kg_class):
"""Test querying triples with subject, predicate, and object specified"""
from trustgraph.schema import TriplesQueryRequest, Value
# Setup mock TrustGraph
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
# Setup mock TrustGraph via factory function
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.return_value = None # SPO query returns None if found
mock_kg_class.return_value = mock_tg_instance
# SPO query returns a list of results (with mock graph attribute)
mock_result = MagicMock()
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
processor = Processor(
taskgroup=MagicMock(),
id='test-cassandra-query',
cassandra_host='localhost'
)
# Create query request with all SPO values
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=100
)
result = await processor.query_triples(query)
# Verify KnowledgeGraph was created with correct parameters
mock_trustgraph.assert_called_once_with(
mock_kg_class.assert_called_once_with(
hosts=['localhost'],
keyspace='test_user'
)
# Verify get_spo was called with correct parameters
mock_tg_instance.get_spo.assert_called_once_with(
'test_collection', 'test_subject', 'test_predicate', 'test_object', limit=100
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
)
# Verify result contains the queried triple
assert len(result) == 1
assert result[0].s.value == 'test_subject'
@ -143,154 +150,174 @@ class TestCassandraQueryProcessor:
assert processor.table is None
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_sp_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_sp_pattern(self, mock_kg_class):
"""Test SP query pattern (subject and predicate, no object)"""
from trustgraph.schema import TriplesQueryRequest, Value
# Setup mock TrustGraph and response
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
# Setup mock TrustGraph via factory function
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.o = 'result_object'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_sp.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=None,
limit=50
)
result = await processor.query_triples(query)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', limit=50)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_s_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_s_pattern(self, mock_kg_class):
"""Test S query pattern (subject only)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.p = 'result_predicate'
mock_result.o = 'result_object'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_s.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=None,
o=None,
limit=25
)
result = await processor.query_triples(query)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', limit=25)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_p_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_p_pattern(self, mock_kg_class):
"""Test P query pattern (predicate only)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_result.o = 'result_object'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_p.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value='test_predicate', is_uri=False),
p=Term(type=LITERAL, value='test_predicate'),
o=None,
limit=10
)
result = await processor.query_triples(query)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', limit=10)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_o_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_o_pattern(self, mock_kg_class):
"""Test O query pattern (object only)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_result.p = 'result_predicate'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_o.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=Value(value='test_object', is_uri=False),
o=Term(type=LITERAL, value='test_object'),
limit=75
)
result = await processor.query_triples(query)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', limit=75)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_get_all_pattern(self, mock_kg_class):
"""Test query pattern with no constraints (get all)"""
from trustgraph.schema import TriplesQueryRequest
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'all_subject'
mock_result.p = 'all_predicate'
mock_result.o = 'all_object'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_all.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
@ -299,9 +326,9 @@ class TestCassandraQueryProcessor:
o=None,
limit=1000
)
result = await processor.query_triples(query)
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1
assert result[0].s.value == 'all_subject'
@ -372,37 +399,44 @@ class TestCassandraQueryProcessor:
run()
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o, g) quad pattern, some values may be\nnull. Output is a list of quads.\n')
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_with_authentication(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_with_authentication(self, mock_kg_class):
"""Test querying with username and password authentication"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.return_value = None
mock_kg_class.return_value = mock_tg_instance
# SPO query returns a list of results
mock_result = MagicMock()
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
processor = Processor(
taskgroup=MagicMock(),
cassandra_username='authuser',
cassandra_password='authpass'
)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=100
)
await processor.query_triples(query)
# Verify KnowledgeGraph was created with authentication
mock_trustgraph.assert_called_once_with(
mock_kg_class.assert_called_once_with(
hosts=['cassandra'], # Updated default
keyspace='test_user',
username='authuser',
@ -410,128 +444,154 @@ class TestCassandraQueryProcessor:
)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_table_reuse(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_table_reuse(self, mock_kg_class):
"""Test that TrustGraph is reused for same table"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.return_value = None
mock_kg_class.return_value = mock_tg_instance
# SPO query returns a list of results
mock_result = MagicMock()
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=100
)
# First query should create TrustGraph
await processor.query_triples(query)
assert mock_trustgraph.call_count == 1
assert mock_kg_class.call_count == 1
# Second query with same table should reuse TrustGraph
await processor.query_triples(query)
assert mock_trustgraph.call_count == 1 # Should not increase
assert mock_kg_class.call_count == 1 # Should not increase
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_table_switching(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_table_switching(self, mock_kg_class):
"""Test table switching creates new TrustGraph"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance1 = MagicMock()
mock_tg_instance2 = MagicMock()
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2]
# Setup mock results for both instances
mock_result = MagicMock()
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_result.p = 'p'
mock_result.o = 'o'
mock_tg_instance1.get_s.return_value = [mock_result]
mock_tg_instance2.get_s.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
# First query
query1 = TriplesQueryRequest(
user='user1',
collection='collection1',
s=Value(value='test_subject', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=None,
o=None,
limit=100
)
await processor.query_triples(query1)
assert processor.table == 'user1'
# Second query with different table
query2 = TriplesQueryRequest(
user='user2',
collection='collection2',
s=Value(value='test_subject', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=None,
o=None,
limit=100
)
await processor.query_triples(query2)
assert processor.table == 'user2'
# Verify TrustGraph was created twice
assert mock_trustgraph.call_count == 2
assert mock_kg_class.call_count == 2
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_exception_handling(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_exception_handling(self, mock_kg_class):
"""Test exception handling during query execution"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=100
)
with pytest.raises(Exception, match="Query failed"):
await processor.query_triples(query)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_multiple_results(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_multiple_results(self, mock_kg_class):
"""Test query returning multiple results"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
# Mock multiple results
mock_result1 = MagicMock()
mock_result1.o = 'object1'
mock_result1.g = ''
mock_result1.otype = None
mock_result1.dtype = None
mock_result1.lang = None
mock_result2 = MagicMock()
mock_result2.o = 'object2'
mock_result2.g = ''
mock_result2.otype = None
mock_result2.dtype = None
mock_result2.lang = None
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=None,
limit=100
)
result = await processor.query_triples(query)
assert len(result) == 2
assert result[0].o.value == 'object1'
assert result[1].o.value == 'object2'
@ -541,16 +601,20 @@ class TestCassandraQueryPerformanceOptimizations:
"""Test cases for multi-table performance optimizations in query service"""
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_get_po_query_optimization(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_get_po_query_optimization(self, mock_kg_class):
"""Test that get_po queries use optimized table (no ALLOW FILTERING)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_po.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
@ -560,8 +624,8 @@ class TestCassandraQueryPerformanceOptimizations:
user='test_user',
collection='test_collection',
s=None,
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=50
)
@ -569,7 +633,7 @@ class TestCassandraQueryPerformanceOptimizations:
# Verify get_po was called (should use optimized po_table)
mock_tg_instance.get_po.assert_called_once_with(
'test_collection', 'test_predicate', 'test_object', limit=50
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
)
assert len(result) == 1
@ -578,16 +642,20 @@ class TestCassandraQueryPerformanceOptimizations:
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_get_os_query_optimization(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_get_os_query_optimization(self, mock_kg_class):
"""Test that get_os queries use optimized table (no ALLOW FILTERING)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.p = 'result_predicate'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_os.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
@ -596,9 +664,9 @@ class TestCassandraQueryPerformanceOptimizations:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=None,
o=Value(value='test_object', is_uri=False),
o=Term(type=LITERAL, value='test_object'),
limit=25
)
@ -606,7 +674,7 @@ class TestCassandraQueryPerformanceOptimizations:
# Verify get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.get_os.assert_called_once_with(
'test_collection', 'test_object', 'test_subject', limit=25
'test_collection', 'test_object', 'test_subject', g=None, limit=25
)
assert len(result) == 1
@ -615,13 +683,13 @@ class TestCassandraQueryPerformanceOptimizations:
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_all_query_patterns_use_correct_tables(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_all_query_patterns_use_correct_tables(self, mock_kg_class):
"""Test that all query patterns route to their optimal tables"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
# Mock empty results for all queries
mock_tg_instance.get_all.return_value = []
@ -655,9 +723,9 @@ class TestCassandraQueryPerformanceOptimizations:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value=s, is_uri=False) if s else None,
p=Value(value=p, is_uri=False) if p else None,
o=Value(value=o, is_uri=False) if o else None,
s=Term(type=LITERAL, value=s) if s else None,
p=Term(type=LITERAL, value=p) if p else None,
o=Term(type=LITERAL, value=o) if o else None,
limit=10
)
@ -687,19 +755,23 @@ class TestCassandraQueryPerformanceOptimizations:
# Mode is determined in KnowledgeGraph initialization
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_performance_critical_po_query_no_filtering(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_performance_critical_po_query_no_filtering(self, mock_kg_class):
"""Test the performance-critical PO query that eliminates ALLOW FILTERING"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
# Mock multiple subjects for the same predicate-object pair
mock_results = []
for i in range(5):
mock_result = MagicMock()
mock_result.s = f'subject_{i}'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_results.append(mock_result)
mock_tg_instance.get_po.return_value = mock_results
@ -711,8 +783,8 @@ class TestCassandraQueryPerformanceOptimizations:
user='large_dataset_user',
collection='massive_collection',
s=None,
p=Value(value='http://www.w3.org/1999/02/22-rdf-syntax-ns#type', is_uri=True),
o=Value(value='http://example.com/Person', is_uri=True),
p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
o=Term(type=IRI, iri='http://example.com/Person'),
limit=1000
)
@ -723,14 +795,15 @@ class TestCassandraQueryPerformanceOptimizations:
'massive_collection',
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
'http://example.com/Person',
g=None,
limit=1000
)
# Verify all results were returned
assert len(result) == 5
for i, triple in enumerate(result):
assert triple.s.value == f'subject_{i}'
assert triple.p.value == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
assert triple.p.is_uri is True
assert triple.o.value == 'http://example.com/Person'
assert triple.o.is_uri is True
assert triple.s.value == f'subject_{i}' # Mock returns literal values
assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
assert triple.p.type == IRI
assert triple.o.iri == 'http://example.com/Person' # URIs use .iri
assert triple.o.type == IRI

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.falkordb.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
class TestFalkorDBQueryProcessor:
@ -25,50 +25,50 @@ class TestFalkorDBQueryProcessor:
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
def test_processor_initialization_with_defaults(self, mock_falkordb):
@ -125,9 +125,9 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -138,8 +138,8 @@ class TestFalkorDBQueryProcessor:
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@ -166,8 +166,8 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -179,13 +179,13 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri_result"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
@ -211,9 +211,9 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=Value(value="literal object", is_uri=False),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -224,12 +224,12 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different predicates
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@ -256,7 +256,7 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100
@ -269,13 +269,13 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different predicate-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "http://example.com/uri2"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.iri == "http://example.com/uri2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
@ -302,8 +302,8 @@ class TestFalkorDBQueryProcessor:
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -314,12 +314,12 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different subjects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@ -347,7 +347,7 @@ class TestFalkorDBQueryProcessor:
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -359,13 +359,13 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different subject-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri2"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
@ -393,7 +393,7 @@ class TestFalkorDBQueryProcessor:
collection='test_collection',
s=None,
p=None,
o=Value(value="literal object", is_uri=False),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -404,12 +404,12 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different subject-predicate pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@ -449,13 +449,13 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].s.iri == "http://example.com/s1"
assert result[0].p.iri == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
assert result[1].s.iri == "http://example.com/s2"
assert result[1].p.iri == "http://example.com/p2"
assert result[1].o.iri == "http://example.com/o2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
@ -476,7 +476,7 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
class TestMemgraphQueryProcessor:
@ -25,50 +25,50 @@ class TestMemgraphQueryProcessor:
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
@ -124,9 +124,9 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -137,8 +137,8 @@ class TestMemgraphQueryProcessor:
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@ -166,8 +166,8 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -179,13 +179,13 @@ class TestMemgraphQueryProcessor:
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri_result"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
@ -212,9 +212,9 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=Value(value="literal object", is_uri=False),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -225,12 +225,12 @@ class TestMemgraphQueryProcessor:
# Verify results contain different predicates
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@ -258,7 +258,7 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100
@ -271,13 +271,13 @@ class TestMemgraphQueryProcessor:
# Verify results contain different predicate-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "http://example.com/uri2"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.iri == "http://example.com/uri2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
@ -305,8 +305,8 @@ class TestMemgraphQueryProcessor:
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -317,12 +317,12 @@ class TestMemgraphQueryProcessor:
# Verify results contain different subjects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@ -351,7 +351,7 @@ class TestMemgraphQueryProcessor:
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -363,13 +363,13 @@ class TestMemgraphQueryProcessor:
# Verify results contain different subject-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri2"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
@ -398,7 +398,7 @@ class TestMemgraphQueryProcessor:
collection='test_collection',
s=None,
p=None,
o=Value(value="literal object", is_uri=False),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -409,12 +409,12 @@ class TestMemgraphQueryProcessor:
# Verify results contain different subject-predicate pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@ -455,13 +455,13 @@ class TestMemgraphQueryProcessor:
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].s.iri == "http://example.com/s1"
assert result[0].p.iri == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
assert result[1].s.iri == "http://example.com/s2"
assert result[1].p.iri == "http://example.com/p2"
assert result[1].o.iri == "http://example.com/o2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
@ -480,7 +480,7 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
class TestNeo4jQueryProcessor:
@ -25,50 +25,50 @@ class TestNeo4jQueryProcessor:
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
@ -124,9 +124,9 @@ class TestNeo4jQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -137,8 +137,8 @@ class TestNeo4jQueryProcessor:
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@ -166,8 +166,8 @@ class TestNeo4jQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -179,13 +179,13 @@ class TestNeo4jQueryProcessor:
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri_result"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
@ -225,13 +225,13 @@ class TestNeo4jQueryProcessor:
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].s.iri == "http://example.com/s1"
assert result[0].p.iri == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
assert result[1].s.iri == "http://example.com/s2"
assert result[1].p.iri == "http://example.com/p2"
assert result[1].o.iri == "http://example.com/o2"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
@ -250,12 +250,12 @@ class TestNeo4jQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100
)
# Should raise the exception
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)

View file

@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
StructuredQueryRequest, StructuredQueryResponse,
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
ObjectsQueryRequest, ObjectsQueryResponse,
RowsQueryRequest, RowsQueryResponse,
Error, GraphQLError
)
from trustgraph.retrieval.structured_query.service import Processor
@ -68,7 +68,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects query service response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
errors=None,
@ -86,7 +86,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -108,7 +108,7 @@ class TestStructuredQueryProcessor:
# Verify objects query service was called correctly
mock_objects_client.request.assert_called_once()
objects_call_args = mock_objects_client.request.call_args[0][0]
assert isinstance(objects_call_args, ObjectsQueryRequest)
assert isinstance(objects_call_args, RowsQueryRequest)
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
assert objects_call_args.variables == {"state": "NY"}
assert objects_call_args.user == "trustgraph"
@ -224,7 +224,7 @@ class TestStructuredQueryProcessor:
assert response.error is not None
assert "empty GraphQL query" in response.error.message
async def test_objects_query_service_error(self, processor):
async def test_rows_query_service_error(self, processor):
"""Test handling of objects query service errors"""
# Arrange
request = StructuredQueryRequest(
@ -250,7 +250,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects query service error
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=Error(type="graphql-execution-error", message="Table 'customers' not found"),
data=None,
errors=None,
@ -267,7 +267,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -284,7 +284,7 @@ class TestStructuredQueryProcessor:
response = response_call[0][0]
assert response.error is not None
assert "Objects query service error" in response.error.message
assert "Rows query service error" in response.error.message
assert "Table 'customers' not found" in response.error.message
async def test_graphql_errors_handling(self, processor):
@ -321,7 +321,7 @@ class TestStructuredQueryProcessor:
)
]
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=None,
errors=graphql_errors,
@ -338,7 +338,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -400,7 +400,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}',
errors=None
@ -416,7 +416,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -464,7 +464,7 @@ class TestStructuredQueryProcessor:
confidence=0.9
)
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=None, # Null data
errors=None,
@ -481,7 +481,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response

View file

@ -10,7 +10,7 @@ import pytest
from unittest.mock import Mock, patch, MagicMock
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
from trustgraph.storage.knowledge.store import Processor as KgStore
@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration:
assert processor.cassandra_password is None
class TestObjectsWriterConfiguration:
class TestRowsWriterConfiguration:
"""Test Cassandra configuration in objects writer processor."""
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_environment_variable_configuration(self, mock_cluster):
"""Test processor picks up configuration from environment variables."""
env_vars = {
@ -97,13 +97,13 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2']
assert processor.cassandra_username == 'obj-env-user'
assert processor.cassandra_password == 'obj-env-pass'
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_cassandra_connection_with_hosts_list(self, mock_cluster):
"""Test that Cassandra connection uses hosts list correctly."""
env_vars = {
@ -118,7 +118,7 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify cluster was called with hosts list
@ -129,8 +129,8 @@ class TestObjectsWriterConfiguration:
assert 'contact_points' in call_args.kwargs
assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3']
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_authentication_configuration(self, mock_auth_provider, mock_cluster):
"""Test authentication is configured when credentials are provided."""
env_vars = {
@ -145,7 +145,7 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify auth provider was created with correct credentials
@ -302,10 +302,10 @@ class TestCommandLineArgumentHandling:
def test_objects_writer_add_args(self):
"""Test that objects writer adds standard Cassandra arguments."""
import argparse
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
parser = argparse.ArgumentParser()
ObjectsWriter.add_args(parser)
RowsWriter.add_args(parser)
# Parse empty args to check that arguments exist
args = parser.parse_args([])

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.graph_embeddings.milvus.write import Processor
from trustgraph.schema import Value, EntityEmbeddings
from trustgraph.schema import Term, EntityEmbeddings, IRI, LITERAL
class TestMilvusGraphEmbeddingsStorageProcessor:
@ -22,11 +22,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
# Create test entities with embeddings
entity1 = EntityEmbeddings(
entity=Value(value='http://example.com/entity1', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/entity1'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
entity2 = EntityEmbeddings(
entity=Value(value='literal entity', is_uri=False),
entity=Term(type=LITERAL, value='literal entity'),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [entity1, entity2]
@ -84,7 +84,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.entities = [entity]
@ -136,7 +136,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='', is_uri=False),
entity=Term(type=LITERAL, value=''),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
@ -155,7 +155,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
entity=Term(type=LITERAL, value=None),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
@ -174,15 +174,15 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
valid_entity = EntityEmbeddings(
entity=Value(value='http://example.com/valid', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/valid'),
vectors=[[0.1, 0.2, 0.3]]
)
empty_entity = EntityEmbeddings(
entity=Value(value='', is_uri=False),
entity=Term(type=LITERAL, value=''),
vectors=[[0.4, 0.5, 0.6]]
)
none_entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
entity=Term(type=LITERAL, value=None),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [valid_entity, empty_entity, none_entity]
@ -217,7 +217,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[]
)
message.entities = [entity]
@ -236,7 +236,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
@ -269,11 +269,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
uri_entity = EntityEmbeddings(
entity=Value(value='http://example.com/uri_entity', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/uri_entity'),
vectors=[[0.1, 0.2, 0.3]]
)
literal_entity = EntityEmbeddings(
entity=Value(value='literal entity text', is_uri=False),
entity=Term(type=LITERAL, value='literal entity text'),
vectors=[[0.4, 0.5, 0.6]]
)
message.entities = [uri_entity, literal_entity]

View file

@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
from trustgraph.schema import IRI, LITERAL
class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
@ -67,7 +68,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'test_collection'
mock_entity = MagicMock()
mock_entity.entity.value = 'test_entity'
mock_entity.entity.type = IRI
mock_entity.entity.iri = 'test_entity'
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_message.entities = [mock_entity]
@ -120,11 +122,13 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'multi_collection'
mock_entity1 = MagicMock()
mock_entity1.entity.value = 'entity_one'
mock_entity1.entity.type = IRI
mock_entity1.entity.iri = 'entity_one'
mock_entity1.vectors = [[0.1, 0.2]]
mock_entity2 = MagicMock()
mock_entity2.entity.value = 'entity_two'
mock_entity2.entity.type = IRI
mock_entity2.entity.iri = 'entity_two'
mock_entity2.vectors = [[0.3, 0.4]]
mock_message.entities = [mock_entity1, mock_entity2]
@ -179,7 +183,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'vector_collection'
mock_entity = MagicMock()
mock_entity.entity.value = 'multi_vector_entity'
mock_entity.entity.type = IRI
mock_entity.entity.iri = 'multi_vector_entity'
mock_entity.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
@ -231,11 +236,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'empty_collection'
mock_entity_empty = MagicMock()
mock_entity_empty.entity.type = LITERAL
mock_entity_empty.entity.value = "" # Empty string
mock_entity_empty.vectors = [[0.1, 0.2]]
mock_entity_none = MagicMock()
mock_entity_none.entity.value = None # None value
mock_entity_none.entity = None # None entity
mock_entity_none.vectors = [[0.3, 0.4]]
mock_message.entities = [mock_entity_empty, mock_entity_none]

View file

@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch, call
from trustgraph.storage.triples.neo4j.write import Processor as StorageProcessor
from trustgraph.query.triples.neo4j.service import Processor as QueryProcessor
from trustgraph.schema import Triples, Triple, Value, Metadata
from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
from trustgraph.schema import TriplesQueryRequest
@ -60,9 +60,9 @@ class TestNeo4jUserCollectionIsolation:
)
triple = Triple(
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal_value", is_uri=False)
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal_value")
)
message = Triples(
@ -128,9 +128,9 @@ class TestNeo4jUserCollectionIsolation:
metadata = Metadata(id="test-id")
triple = Triple(
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="http://example.com/object", is_uri=True)
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=IRI, iri="http://example.com/object")
)
message = Triples(
@ -170,8 +170,8 @@ class TestNeo4jUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None
)
@ -254,9 +254,9 @@ class TestNeo4jUserCollectionIsolation:
metadata=Metadata(user="user1", collection="coll1"),
triples=[
Triple(
s=Value(value="http://example.com/user1/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="user1_data", is_uri=False)
s=Term(type=IRI, iri="http://example.com/user1/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user1_data")
)
]
)
@ -265,9 +265,9 @@ class TestNeo4jUserCollectionIsolation:
metadata=Metadata(user="user2", collection="coll2"),
triples=[
Triple(
s=Value(value="http://example.com/user2/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="user2_data", is_uri=False)
s=Term(type=IRI, iri="http://example.com/user2/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user2_data")
)
]
)
@ -429,9 +429,9 @@ class TestNeo4jUserCollectionRegression:
metadata=Metadata(user="user1", collection="coll1"),
triples=[
Triple(
s=Value(value=shared_uri, is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="user1_value", is_uri=False)
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user1_value")
)
]
)
@ -440,9 +440,9 @@ class TestNeo4jUserCollectionRegression:
metadata=Metadata(user="user2", collection="coll2"),
triples=[
Triple(
s=Value(value=shared_uri, is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="user2_value", is_uri=False)
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user2_value")
)
]
)

View file

@ -1,533 +0,0 @@
"""
Unit tests for Cassandra Object Storage Processor
Tests the business logic of the object storage processor including:
- Schema configuration handling
- Type conversions
- Name sanitization
- Table structure generation
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class TestObjectsCassandraStorageLogic:
"""Test business logic without FlowProcessor dependencies"""
def test_sanitize_name(self):
"""Test name sanitization for Cassandra compatibility"""
processor = MagicMock()
# Bind the real (unbound) Processor.sanitize_name onto a bare mock via the
# descriptor protocol, so the real logic runs without building a Processor.
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
# Test various name patterns (back to original logic)
assert processor.sanitize_name("simple_name") == "simple_name"
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
# Names starting with a digit get an "o_" prefix (Cassandra identifiers
# cannot start with a number).
assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
def test_get_cassandra_type(self):
"""Test field type conversion to Cassandra types"""
processor = MagicMock()
# Bind the real type-mapping helper onto the mock instance.
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
# Basic type mappings
assert processor.get_cassandra_type("string") == "text"
assert processor.get_cassandra_type("boolean") == "boolean"
assert processor.get_cassandra_type("timestamp") == "timestamp"
assert processor.get_cassandra_type("uuid") == "uuid"
# Integer types with size hints
assert processor.get_cassandra_type("integer", size=2) == "int"
assert processor.get_cassandra_type("integer", size=8) == "bigint"
# Float types with size hints
assert processor.get_cassandra_type("float", size=2) == "float"
assert processor.get_cassandra_type("float", size=8) == "double"
# Unknown type defaults to text
assert processor.get_cassandra_type("unknown_type") == "text"
def test_convert_value(self):
"""Test value conversion for different field types"""
processor = MagicMock()
# Bind the real value-conversion helper onto the mock instance.
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
# Integer conversions
assert processor.convert_value("123", "integer") == 123
assert processor.convert_value(123.5, "integer") == 123
# None passes through unconverted rather than raising.
assert processor.convert_value(None, "integer") is None
# Float conversions
assert processor.convert_value("123.45", "float") == 123.45
assert processor.convert_value(123, "float") == 123.0
# Boolean conversions: several truthy/falsy string spellings are accepted.
assert processor.convert_value("true", "boolean") is True
assert processor.convert_value("false", "boolean") is False
assert processor.convert_value("1", "boolean") is True
assert processor.convert_value("0", "boolean") is False
assert processor.convert_value("yes", "boolean") is True
assert processor.convert_value("no", "boolean") is False
# String conversions
assert processor.convert_value(123, "string") == "123"
assert processor.convert_value(True, "string") == "True"
def test_table_creation_cql_generation(self):
"""Test CQL generation for table creation"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.session = MagicMock()
# Bind the real helpers so ensure_table exercises genuine sanitization
# and type-mapping logic against the mocked Cassandra session.
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
# Stub keyspace creation: just record the keyspace so ensure_table proceeds.
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create test schema
schema = RowSchema(
name="customer_records",
description="Test customer schema",
fields=[
Field(
name="customer_id",
type="string",
size=50,
primary=True,
required=True,
indexed=False
),
Field(
name="email",
type="string",
size=100,
required=True,
indexed=True
),
Field(
name="age",
type="integer",
size=4,
required=False,
indexed=False
)
]
)
# Call ensure_table
processor.ensure_table("test_user", "customer_records", schema)
# Verify keyspace was ensured (check that it was added to known_keyspaces)
assert "test_user" in processor.known_keyspaces
# Check the CQL that was executed (first call should be table creation)
all_calls = processor.session.execute.call_args_list
# call[0] is positional args; [0] again is the CQL string.
table_creation_cql = all_calls[0][0][0] # First call
# Verify table structure (keyspace uses sanitize_name, table uses sanitize_table)
assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql
assert "collection text" in table_creation_cql
assert "customer_id text" in table_creation_cql
assert "email text" in table_creation_cql
assert "age int" in table_creation_cql
assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql
def test_table_creation_without_primary_key(self):
"""Test table creation when no primary key is defined"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.session = MagicMock()
# Bind real helpers onto the mock (see test_table_creation_cql_generation).
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
# Stub keyspace creation: record it so ensure_table proceeds to CQL.
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema without primary key
schema = RowSchema(
name="events",
description="Event log",
fields=[
Field(name="event_type", type="string", size=50),
Field(name="timestamp", type="timestamp", size=0)
]
)
# Call ensure_table
processor.ensure_table("test_user", "events", schema)
# With no primary field, a synthetic uuid key must be generated.
# Check the CQL includes synthetic_id (field names don't get o_ prefix)
executed_cql = processor.session.execute.call_args[0][0]
assert "synthetic_id uuid" in executed_cql
assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configurations"""
processor = MagicMock()
processor.schemas = {}
processor.config_key = "schema"
# Bind the real async config handler onto the mock instance.
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test configuration: schema definitions arrive as JSON strings
# keyed by schema name under the processor's config_key.
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer data",
"fields": [
{
"name": "id",
"type": "string",
"primary_key": True,
"required": True
},
{
"name": "name",
"type": "string",
"required": True
},
{
"name": "balance",
"type": "float",
"size": 8
}
]
})
}
}
# Process configuration
await processor.on_schema_config(config, version=1)
# Verify schema was loaded
assert "customer_records" in processor.schemas
schema = processor.schemas["customer_records"]
assert schema.name == "customer_records"
assert len(schema.fields) == 3
# Check field properties
id_field = schema.fields[0]
assert id_field.name == "id"
assert id_field.type == "string"
assert id_field.primary is True
# Note: Field.required always returns False due to Pulsar schema limitations
# The actual required value is tracked during schema parsing
@pytest.mark.asyncio
async def test_object_processing_logic(self):
"""Test the logic for processing ExtractedObject"""
processor = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="integer", size=4)
]
)
}
processor.ensure_table = MagicMock()
# Bind real helpers so on_object runs genuine sanitization/conversion.
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create test object
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
metadata=[]
),
schema_name="test_schema",
values=[{"id": "123", "value": "456"}],
confidence=0.9,
source_span="test source"
)
# Create mock message
msg = MagicMock()
msg.value.return_value = test_obj
# Process object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once_with("test_user", "test_schema", processor.schemas["test_schema"])
# Verify insert was executed (keyspace normal, table with o_ prefix)
processor.session.execute.assert_called_once()
insert_cql = processor.session.execute.call_args[0][0]
values = processor.session.execute.call_args[0][1]
assert "INSERT INTO test_user.o_test_schema" in insert_cql
assert "collection" in insert_cql
assert values[0] == "test_collection" # collection value
assert values[1] == "123" # id value (from values[0])
assert values[2] == 456 # converted integer value (from values[0])
def test_secondary_index_creation(self):
"""Test that secondary indexes are created for indexed fields"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
processor.session = MagicMock()
# Bind real helpers onto the mock instance.
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
# Stub keyspace creation; guard against clobbering pre-populated tables.
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
if keyspace not in processor.known_tables:
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema with indexed field
schema = RowSchema(
name="products",
description="Product catalog",
fields=[
Field(name="product_id", type="string", size=50, primary=True),
Field(name="category", type="string", size=30, indexed=True),
Field(name="price", type="float", size=8, indexed=True)
]
)
# Call ensure_table
processor.ensure_table("test_user", "products", schema)
# Should have 3 calls: create table + 2 indexes
assert processor.session.execute.call_count == 3
# Check index creation calls (table has o_ prefix, fields don't)
calls = processor.session.execute.call_args_list
index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]]
assert len(index_calls) == 2
assert any("o_products_category_idx" in call for call in index_calls)
assert any("o_products_price_idx" in call for call in index_calls)
class TestObjectsCassandraStorageBatchLogic:
    """Test batch processing logic in Cassandra storage"""

    @pytest.mark.asyncio
    async def test_batch_object_processing_logic(self):
        """Test processing of batch ExtractedObjects.

        Binds the real on_object / conversion helpers onto a mock processor,
        feeds it a three-item batch, and verifies one INSERT per item with
        correctly converted values.
        """
        processor = MagicMock()
        processor.schemas = {
            "batch_schema": RowSchema(
                name="batch_schema",
                description="Test batch schema",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="name", type="string", size=100),
                    Field(name="value", type="integer", size=4)
                ]
            )
        }
        processor.known_keyspaces = {"test_user"}  # Pre-populate to skip validation query
        processor.ensure_table = MagicMock()
        # Bind the real Processor helpers onto the mock so genuine
        # sanitization/conversion logic is exercised.
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
        processor.session = MagicMock()
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        # Create batch object with multiple values
        batch_obj = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_collection",
                metadata=[]
            ),
            schema_name="batch_schema",
            values=[
                {"id": "001", "name": "First", "value": "100"},
                {"id": "002", "name": "Second", "value": "200"},
                {"id": "003", "name": "Third", "value": "300"}
            ],
            confidence=0.95,
            source_span="batch source"
        )

        # Create mock message
        msg = MagicMock()
        msg.value.return_value = batch_obj

        # Process batch object
        await processor.on_object(msg, None, None)

        # Verify table was ensured once for the whole batch
        processor.ensure_table.assert_called_once_with(
            "test_user", "batch_schema", processor.schemas["batch_schema"])

        # Verify 3 separate insert calls (one per batch item)
        assert processor.session.execute.call_count == 3

        # Check each insert call.
        # BUG FIX: the original
        #   assert values[2] == f"First" if i == 0 else f"Second" if i == 1 else f"Third"
        # parsed as `assert (X if i == 0 else "Second"/"Third")`, so for i > 0
        # it only asserted the truthiness of a non-empty string and never
        # checked the name column. Compare against an explicit list instead.
        expected_names = ["First", "Second", "Third"]
        for i, call_obj in enumerate(processor.session.execute.call_args_list):
            insert_cql = call_obj[0][0]
            values = call_obj[0][1]
            assert "INSERT INTO test_user.o_batch_schema" in insert_cql
            assert "collection" in insert_cql
            # Check values for each batch item
            assert values[0] == "batch_collection"  # collection
            assert values[1] == f"00{i+1}"          # id from batch item i
            assert values[2] == expected_names[i]   # name column
            assert values[3] == (i + 1) * 100       # converted integer value
@pytest.mark.asyncio
async def test_empty_batch_processing_logic(self):
"""Test processing of empty batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
}
processor.ensure_table = MagicMock()
# Bind real helpers onto the mock instance.
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create empty batch object
empty_batch_obj = ExtractedObject(
metadata=Metadata(
id="empty-001",
user="test_user",
collection="empty_collection",
metadata=[]
),
schema_name="empty_schema",
values=[], # Empty batch
confidence=1.0,
source_span="empty source"
)
msg = MagicMock()
msg.value.return_value = empty_batch_obj
# Process empty batch object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once()
# Verify no insert calls for empty batch
processor.session.execute.assert_not_called()
@pytest.mark.asyncio
async def test_single_item_batch_processing_logic(self):
"""Test processing of single-item batch (backward compatibility)"""
processor = MagicMock()
processor.schemas = {
"single_schema": RowSchema(
name="single_schema",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="data", type="string", size=100)
]
)
}
processor.ensure_table = MagicMock()
# Bind real helpers onto the mock instance.
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create single-item batch object (backward compatibility case)
single_batch_obj = ExtractedObject(
metadata=Metadata(
id="single-001",
user="test_user",
collection="single_collection",
metadata=[]
),
schema_name="single_schema",
values=[{"id": "single-1", "data": "single data"}], # Array with one item
confidence=0.8,
source_span="single source"
)
msg = MagicMock()
msg.value.return_value = single_batch_obj
# Process single-item batch object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once()
# Verify exactly one insert call
processor.session.execute.assert_called_once()
insert_cql = processor.session.execute.call_args[0][0]
values = processor.session.execute.call_args[0][1]
assert "INSERT INTO test_user.o_single_schema" in insert_cql
assert values[0] == "single_collection" # collection
assert values[1] == "single-1" # id value
assert values[2] == "single data" # data value
def test_batch_value_conversion_logic(self):
"""Test value conversion works correctly for batch items"""
processor = MagicMock()
# Bind the real conversion helper onto the mock instance.
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
# Test various conversion scenarios that would occur in batch processing
# Each tuple is (input value, declared field type, expected converted value).
test_cases = [
# Integer conversions for batch items
("123", "integer", 123),
("456", "integer", 456),
("789", "integer", 789),
# Float conversions for batch items
("12.5", "float", 12.5),
("34.7", "float", 34.7),
# Boolean conversions for batch items
("true", "boolean", True),
("false", "boolean", False),
("1", "boolean", True),
("0", "boolean", False),
# String conversions for batch items
(123, "string", "123"),
(45.6, "string", "45.6"),
]
for input_val, field_type, expected_output in test_cases:
result = processor.convert_value(input_val, field_type)
assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}"

View file

@ -0,0 +1,435 @@
"""
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
"""Test Qdrant row embeddings storage functionality"""
# Patch QdrantClient at the module under test so no real connection is made.
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_processor_initialization_basic(self, mock_qdrant_client):
"""Test basic Qdrant processor initialization"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Constructor must pass store_uri/api_key straight through to the client.
mock_qdrant_client.assert_called_once_with(
url='http://localhost:6333', api_key='test-api-key'
)
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
"""Test processor initialization with default values"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# No store_uri/api_key supplied: defaults must be used.
config = {
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
# Default URL is the local Qdrant endpoint; api_key defaults to None.
mock_qdrant_client.assert_called_once_with(
url='http://localhost:6333', api_key=None
)
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_sanitize_name(self, mock_qdrant_client):
"""Test name sanitization for Qdrant collections"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_client.return_value = MagicMock()
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Test basic sanitization
assert processor.sanitize_name("simple") == "simple"
assert processor.sanitize_name("with-dash") == "with_dash"
assert processor.sanitize_name("with.dot") == "with_dot"
assert processor.sanitize_name("UPPERCASE") == "uppercase"
# Test numeric prefix handling
# Names not starting with a letter get an "r_" (row) prefix.
assert processor.sanitize_name("123start") == "r_123start"
assert processor.sanitize_name("_underscore") == "r__underscore"
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_get_collection_name(self, mock_qdrant_client):
"""Test Qdrant collection name generation"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_client.return_value = MagicMock()
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
collection_name = processor.get_collection_name(
user="test_user",
collection="test_collection",
schema_name="customer_data",
dimension=384
)
# Expected pattern: rows_<user>_<collection>_<schema>_<dimension>.
assert collection_name == "rows_test_user_test_collection_customer_data_384"
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
"""Test that ensure_collection creates a new collection when needed"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
# Report the collection as absent so creation is triggered.
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.ensure_collection("test_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
mock_qdrant_instance.create_collection.assert_called_once()
# Verify the collection is cached
assert "test_collection" in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
"""Test that ensure_collection skips creation when collection exists"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
# Report the collection as already present: no create call expected.
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.ensure_collection("existing_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once()
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_uses_cache(self, mock_qdrant_client):
    """A cached collection short-circuits both the existence check and creation."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    client = MagicMock()
    mock_qdrant_client.return_value = client

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    processor.created_collections.add("cached_collection")

    processor.ensure_collection("cached_collection", 384)

    # Cached entry means no Qdrant traffic at all.
    client.collection_exists.assert_not_called()
    client.create_collection.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client):
    """A single row embedding is upserted into the dimension-suffixed collection.

    Fix: drop the unused ``Metadata`` import — the test builds its metadata
    from a MagicMock and never instantiates the schema type.
    """
    from trustgraph.storage.row_embeddings.qdrant.write import Processor
    from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

    client = MagicMock()
    client.collection_exists.return_value = True
    mock_qdrant_client.return_value = client
    mock_uuid.uuid4.return_value = 'test-uuid-123'

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    # Pre-register the (user, collection) pair so the message is not dropped.
    processor.known_collections[('test_user', 'test_collection')] = {}

    metadata = MagicMock()
    metadata.user = 'test_user'
    metadata.collection = 'test_collection'
    metadata.id = 'doc-123'

    embeddings_msg = RowEmbeddings(
        metadata=metadata,
        schema_name='customers',
        embeddings=[
            RowIndexEmbedding(
                index_name='customer_id',
                index_value=['CUST001'],
                text='CUST001',
                vectors=[[0.1, 0.2, 0.3]],
            )
        ],
    )

    mock_msg = MagicMock()
    mock_msg.value.return_value = embeddings_msg

    await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())

    client.upsert.assert_called_once()
    kwargs = client.upsert.call_args[1]
    # The trailing "_3" is the vector dimension, derived from the vector length.
    assert kwargs['collection_name'] == 'rows_test_user_test_collection_customers_3'
    point = kwargs['points'][0]
    assert point.vector == [0.1, 0.2, 0.3]
    assert point.payload['index_name'] == 'customer_id'
    assert point.payload['index_value'] == ['CUST001']
    assert point.payload['text'] == 'CUST001'
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
    """Each vector of an embedding triggers its own upsert call."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor
    from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

    client = MagicMock()
    client.collection_exists.return_value = True
    mock_qdrant_client.return_value = client
    mock_uuid.uuid4.return_value = 'test-uuid'

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    processor.known_collections[('test_user', 'test_collection')] = {}

    metadata = MagicMock()
    metadata.user = 'test_user'
    metadata.collection = 'test_collection'
    metadata.id = 'doc-123'

    embeddings_msg = RowEmbeddings(
        metadata=metadata,
        schema_name='people',
        embeddings=[
            RowIndexEmbedding(
                index_name='name',
                index_value=['John Doe'],
                text='John Doe',
                vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
            )
        ],
    )

    mock_msg = MagicMock()
    mock_msg.value.return_value = embeddings_msg

    await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())

    # One upsert per vector: three vectors, three calls.
    assert client.upsert.call_count == 3
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
    """An embedding carrying no vectors produces no upsert."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor
    from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

    client = MagicMock()
    mock_qdrant_client.return_value = client

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    processor.known_collections[('test_user', 'test_collection')] = {}

    metadata = MagicMock()
    metadata.user = 'test_user'
    metadata.collection = 'test_collection'
    metadata.id = 'doc-123'

    empty_embedding = RowIndexEmbedding(
        index_name='id',
        index_value=['123'],
        text='123',
        vectors=[],  # nothing to store
    )
    embeddings_msg = RowEmbeddings(
        metadata=metadata,
        schema_name='items',
        embeddings=[empty_embedding],
    )

    mock_msg = MagicMock()
    mock_msg.value.return_value = embeddings_msg

    await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())

    client.upsert.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client):
    """Messages for a (user, collection) pair that was never registered are dropped."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor
    from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

    client = MagicMock()
    mock_qdrant_client.return_value = client

    # Deliberately skip known_collections registration.
    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    metadata = MagicMock()
    metadata.user = 'unknown_user'
    metadata.collection = 'unknown_collection'
    metadata.id = 'doc-123'

    embeddings_msg = RowEmbeddings(
        metadata=metadata,
        schema_name='items',
        embeddings=[
            RowIndexEmbedding(
                index_name='id',
                index_value=['123'],
                text='123',
                vectors=[[0.1, 0.2]],
            )
        ],
    )

    mock_msg = MagicMock()
    mock_msg.value.return_value = embeddings_msg

    await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())

    client.upsert.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection(self, mock_qdrant_client):
    """delete_collection removes only collections matching the user/collection prefix."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    names = [
        'rows_test_user_test_collection_schema1_384',
        'rows_test_user_test_collection_schema2_384',
        'rows_other_user_other_collection_schema_384',
    ]
    colls = []
    for n in names:
        c = MagicMock()
        c.name = n
        colls.append(c)

    listing = MagicMock()
    listing.collections = colls

    client = MagicMock()
    client.get_collections.return_value = listing
    mock_qdrant_client.return_value = client

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')
    processor.created_collections.add('rows_test_user_test_collection_schema1_384')

    await processor.delete_collection('test_user', 'test_collection')

    # The two test_user/test_collection collections go; the other user's stays.
    assert client.delete_collection.call_count == 2
    # The cache entry for the deleted collection must also be evicted.
    assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client):
    """delete_collection_schema only removes collections of the named schema."""
    from trustgraph.storage.row_embeddings.qdrant.write import Processor

    colls = []
    for n in (
        'rows_test_user_test_collection_customers_384',
        'rows_test_user_test_collection_orders_384',
    ):
        c = MagicMock()
        c.name = n
        colls.append(c)

    listing = MagicMock()
    listing.collections = colls

    client = MagicMock()
    client.get_collections.return_value = listing
    mock_qdrant_client.return_value = client

    processor = Processor(taskgroup=AsyncMock(), id='test-processor')

    await processor.delete_collection_schema(
        'test_user', 'test_collection', 'customers'
    )

    client.delete_collection.assert_called_once()
    deleted_name = client.delete_collection.call_args[0][0]
    assert deleted_name == 'rows_test_user_test_collection_customers_384'
# Allow running this test module directly (python <file>) as well as via pytest.
if __name__ == '__main__':
    pytest.main([__file__])

View file

@ -0,0 +1,474 @@
"""
Unit tests for Cassandra Row Storage Processor (Unified Table Implementation)
Tests the business logic of the row storage processor including:
- Schema configuration handling
- Name sanitization
- Unified table structure
- Index management
- Row storage with multi-index support
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class TestRowsCassandraStorageLogic:
"""Test business logic for unified table implementation"""
def test_sanitize_name(self):
    """sanitize_name lowercases and maps non-identifier characters for Cassandra."""
    host = MagicMock()
    sanitize = Processor.sanitize_name.__get__(host, Processor)

    # raw input -> expected sanitized identifier
    cases = {
        "simple_name": "simple_name",
        "Name-With-Dashes": "name_with_dashes",
        "name.with.dots": "name_with_dots",
        "123_starts_with_number": "r_123_starts_with_number",
        "name with spaces": "name_with_spaces",
        "special!@#$%^chars": "special______chars",
        "UPPERCASE": "uppercase",
        "CamelCase": "camelcase",
        "_underscore_start": "r__underscore_start",
    }
    for raw, expected in cases.items():
        assert sanitize(raw) == expected
def test_get_index_names(self):
    """get_index_names returns the primary key plus all indexed fields."""
    host = MagicMock()
    get_index_names = Processor.get_index_names.__get__(host, Processor)

    schema = RowSchema(
        name="test_schema",
        description="Test",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="category", type="string", indexed=True),
            Field(name="name", type="string"),  # neither primary nor indexed
            Field(name="status", type="string", indexed=True),
        ],
    )

    names = get_index_names(schema)

    assert len(names) == 3
    for expected in ("id", "category", "status"):
        assert expected in names
    # Plain fields are excluded.
    assert "name" not in names
def test_get_index_names_no_indexes(self):
    """A schema with neither primary nor indexed fields yields no index names."""
    host = MagicMock()
    get_index_names = Processor.get_index_names.__get__(host, Processor)

    plain_schema = RowSchema(
        name="no_index_schema",
        fields=[
            Field(name="data1", type="string"),
            Field(name="data2", type="string"),
        ],
    )

    assert len(get_index_names(plain_schema)) == 0
def test_build_index_value(self):
    """build_index_value extracts one field's value as a single-element list."""
    host = MagicMock()
    build = Processor.build_index_value.__get__(host, Processor)

    row = {"id": "123", "category": "electronics", "name": "Widget"}

    assert build(row, "id") == ["123"]
    assert build(row, "category") == ["electronics"]
    # An absent field maps to an empty string, not an error.
    assert build(row, "missing") == [""]
def test_build_index_value_composite(self):
    """A comma-separated index spec produces one list element per field."""
    host = MagicMock()
    build = Processor.build_index_value.__get__(host, Processor)

    row = {"region": "us-west", "category": "electronics", "id": "123"}

    assert build(row, "region,category") == ["us-west", "electronics"]
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
    """on_schema_config parses JSON schema definitions into RowSchema objects."""
    processor = MagicMock()
    processor.schemas = {}
    processor.config_key = "schema"
    processor.registered_partitions = set()
    processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)

    schema_json = json.dumps({
        "name": "customer_records",
        "description": "Customer data",
        "fields": [
            {"name": "id", "type": "string", "primary_key": True, "required": True},
            {"name": "name", "type": "string", "required": True},
            {"name": "category", "type": "string", "indexed": True},
        ],
    })
    await processor.on_schema_config(
        {"schema": {"customer_records": schema_json}}, version=1,
    )

    assert "customer_records" in processor.schemas
    parsed = processor.schemas["customer_records"]
    assert parsed.name == "customer_records"
    assert len(parsed.fields) == 3

    # The JSON "primary_key" flag maps onto the Field.primary attribute.
    first = parsed.fields[0]
    assert first.name == "id"
    assert first.type == "string"
    assert first.primary is True
@pytest.mark.asyncio
async def test_object_processing_stores_data_map(self):
    """on_object writes the full row as a map<text, text> into the unified table."""
    processor = MagicMock()
    processor.schemas = {
        "test_schema": RowSchema(
            name="test_schema",
            description="Test",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="value", type="string", size=100),
            ],
        )
    }
    processor.tables_initialized = {"test_user"}
    processor.registered_partitions = set()
    processor.session = MagicMock()
    processor.ensure_tables = MagicMock()
    processor.register_partitions = MagicMock()
    processor.collection_exists = MagicMock(return_value=True)

    # Bind the real implementations onto the mock host.
    for method in ("sanitize_name", "get_index_names", "build_index_value", "on_object"):
        setattr(processor, method, getattr(Processor, method).__get__(processor, Processor))

    obj = ExtractedObject(
        metadata=Metadata(
            id="test-001",
            user="test_user",
            collection="test_collection",
            metadata=[],
        ),
        schema_name="test_schema",
        values=[{"id": "123", "value": "test_data"}],
        confidence=0.9,
        source_span="test source",
    )
    msg = MagicMock()
    msg.value.return_value = obj

    await processor.on_object(msg, None, None)

    processor.session.execute.assert_called()
    cql, bound = processor.session.execute.call_args[0][:2]

    assert "INSERT INTO test_user.rows" in cql
    # Bound values: (collection, schema_name, index_name, index_value, data, source)
    assert bound[0] == "test_collection"
    assert bound[1] == "test_schema"
    assert bound[2] == "id"
    assert bound[3] == ["123"]
    assert bound[4] == {"id": "123", "value": "test_data"}
    # NOTE(review): source column is asserted empty even though source_span is
    # set — presumably the writer does not persist source_span here; confirm.
    assert bound[5] == ""
@pytest.mark.asyncio
async def test_object_processing_multiple_indexes(self):
    """One insert is issued per index name (primary key + each indexed field)."""
    processor = MagicMock()
    processor.schemas = {
        "multi_index_schema": RowSchema(
            name="multi_index_schema",
            fields=[
                Field(name="id", type="string", primary=True),
                Field(name="category", type="string", indexed=True),
                Field(name="status", type="string", indexed=True),
            ],
        )
    }
    processor.tables_initialized = {"test_user"}
    processor.registered_partitions = set()
    processor.session = MagicMock()
    processor.ensure_tables = MagicMock()
    processor.register_partitions = MagicMock()
    processor.collection_exists = MagicMock(return_value=True)

    for method in ("sanitize_name", "get_index_names", "build_index_value", "on_object"):
        setattr(processor, method, getattr(Processor, method).__get__(processor, Processor))

    obj = ExtractedObject(
        metadata=Metadata(
            id="test-001",
            user="test_user",
            collection="test_collection",
            metadata=[],
        ),
        schema_name="multi_index_schema",
        values=[{"id": "123", "category": "electronics", "status": "active"}],
        confidence=0.9,
        source_span="",
    )
    msg = MagicMock()
    msg.value.return_value = obj

    await processor.on_object(msg, None, None)

    # Three index names -> three inserts of the same row.
    assert processor.session.execute.call_count == 3
    seen = {call[0][1][2] for call in processor.session.execute.call_args_list}
    assert seen == {"id", "category", "status"}
class TestRowsCassandraStorageBatchLogic:
"""Test batch processing logic for unified table implementation"""
@pytest.mark.asyncio
async def test_batch_object_processing(self):
    """A batch ExtractedObject produces one insert per row in the batch."""
    processor = MagicMock()
    processor.schemas = {
        "batch_schema": RowSchema(
            name="batch_schema",
            fields=[
                Field(name="id", type="string", primary=True),
                Field(name="name", type="string"),
            ],
        )
    }
    processor.tables_initialized = {"test_user"}
    processor.registered_partitions = set()
    processor.session = MagicMock()
    processor.ensure_tables = MagicMock()
    processor.register_partitions = MagicMock()
    processor.collection_exists = MagicMock(return_value=True)

    for method in ("sanitize_name", "get_index_names", "build_index_value", "on_object"):
        setattr(processor, method, getattr(Processor, method).__get__(processor, Processor))

    batch = ExtractedObject(
        metadata=Metadata(
            id="batch-001",
            user="test_user",
            collection="batch_collection",
            metadata=[],
        ),
        schema_name="batch_schema",
        values=[
            {"id": "001", "name": "First"},
            {"id": "002", "name": "Second"},
            {"id": "003", "name": "Third"},
        ],
        confidence=0.95,
        source_span="",
    )
    msg = MagicMock()
    msg.value.return_value = batch

    await processor.on_object(msg, None, None)

    # One index (the primary key) per row, three rows -> three inserts.
    assert processor.session.execute.call_count == 3
    inserted_ids = {
        tuple(call[0][1][3])  # index_value is the 4th bound value
        for call in processor.session.execute.call_args_list
    }
    assert inserted_ids == {("001",), ("002",), ("003",)}
@pytest.mark.asyncio
async def test_empty_batch_processing(self):
    """An ExtractedObject with no values results in zero inserts."""
    processor = MagicMock()
    processor.schemas = {
        "empty_schema": RowSchema(
            name="empty_schema",
            fields=[Field(name="id", type="string", primary=True)],
        )
    }
    processor.tables_initialized = {"test_user"}
    processor.registered_partitions = set()
    processor.session = MagicMock()
    processor.ensure_tables = MagicMock()
    processor.register_partitions = MagicMock()
    processor.collection_exists = MagicMock(return_value=True)

    for method in ("sanitize_name", "get_index_names", "build_index_value", "on_object"):
        setattr(processor, method, getattr(Processor, method).__get__(processor, Processor))

    empty = ExtractedObject(
        metadata=Metadata(
            id="empty-001",
            user="test_user",
            collection="empty_collection",
            metadata=[],
        ),
        schema_name="empty_schema",
        values=[],  # nothing to write
        confidence=1.0,
        source_span="",
    )
    msg = MagicMock()
    msg.value.return_value = empty

    await processor.on_object(msg, None, None)

    processor.session.execute.assert_not_called()
class TestUnifiedTableStructure:
"""Test the unified rows table structure"""
def test_ensure_tables_creates_unified_structure(self):
    """ensure_tables creates the unified rows table plus the partitions table."""
    processor = MagicMock()
    processor.known_keyspaces = {"test_user"}
    processor.tables_initialized = set()
    processor.session = MagicMock()
    processor.ensure_keyspace = MagicMock()
    processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
    processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

    processor.ensure_tables("test_user")

    # Exactly two DDL statements: rows + row_partitions.
    assert processor.session.execute.call_count == 2
    rows_cql, partitions_cql = (
        call[0][0] for call in processor.session.execute.call_args_list
    )

    assert "CREATE TABLE IF NOT EXISTS test_user.rows" in rows_cql
    for fragment in (
        "collection text",
        "schema_name text",
        "index_name text",
        "index_value frozen<list<text>>",
        "data map<text, text>",
        "source text",
        "PRIMARY KEY ((collection, schema_name, index_name), index_value)",
    ):
        assert fragment in rows_cql

    assert "CREATE TABLE IF NOT EXISTS test_user.row_partitions" in partitions_cql
    assert "PRIMARY KEY ((collection), schema_name, index_name)" in partitions_cql

    # The keyspace is remembered so the DDL is not re-run.
    assert "test_user" in processor.tables_initialized
def test_ensure_tables_idempotent(self):
    """A keyspace already marked initialized triggers no further DDL."""
    processor = MagicMock()
    processor.tables_initialized = {"test_user"}  # pre-initialized
    processor.session = MagicMock()
    processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

    processor.ensure_tables("test_user")

    processor.session.execute.assert_not_called()
class TestPartitionRegistration:
"""Test partition registration for tracking what's stored"""
def test_register_partitions(self):
    """register_partitions records one partition row per index of the schema."""
    processor = MagicMock()
    processor.registered_partitions = set()
    processor.session = MagicMock()
    processor.schemas = {
        "test_schema": RowSchema(
            name="test_schema",
            fields=[
                Field(name="id", type="string", primary=True),
                Field(name="category", type="string", indexed=True),
            ],
        )
    }
    for method in ("sanitize_name", "get_index_names", "register_partitions"):
        setattr(processor, method, getattr(Processor, method).__get__(processor, Processor))

    processor.register_partitions("test_user", "test_collection", "test_schema")

    # Two indexes (id, category) -> two partition inserts.
    assert processor.session.execute.call_count == 2
    assert ("test_collection", "test_schema") in processor.registered_partitions
def test_register_partitions_idempotent(self):
    """An already-registered partition pair causes no CQL at all."""
    processor = MagicMock()
    processor.registered_partitions = {("test_collection", "test_schema")}
    processor.session = MagicMock()
    processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)

    processor.register_partitions("test_user", "test_collection", "test_schema")

    processor.session.execute.assert_not_called()

View file

@ -6,7 +6,8 @@ import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.storage.triples.cassandra.write import Processor
from trustgraph.schema import Value, Triple
from trustgraph.schema import Triple, LITERAL, IRI
from trustgraph.direct.cassandra_kg import DEFAULT_GRAPH
class TestCassandraStorageProcessor:
@ -86,29 +87,29 @@ class TestCassandraStorageProcessor:
assert processor.cassandra_username == 'new-user' # Only cassandra_* params work
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_table_switching_with_auth(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_table_switching_with_auth(self, mock_kg_class):
"""Test table switching logic when authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(
taskgroup=taskgroup_mock,
cassandra_username='testuser',
cassandra_password='testpass'
)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify KnowledgeGraph was called with auth parameters
mock_trustgraph.assert_called_once_with(
mock_kg_class.assert_called_once_with(
hosts=['cassandra'], # Updated default
keyspace='user1',
username='testuser',
@ -117,128 +118,150 @@ class TestCassandraStorageProcessor:
assert processor.table == 'user1'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_table_switching_without_auth(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_table_switching_without_auth(self, mock_kg_class):
"""Test table switching logic when no authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user2'
mock_message.metadata.collection = 'collection2'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify KnowledgeGraph was called without auth parameters
mock_trustgraph.assert_called_once_with(
mock_kg_class.assert_called_once_with(
hosts=['cassandra'], # Updated default
keyspace='user2'
)
assert processor.table == 'user2'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_table_reuse_when_same(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_table_reuse_when_same(self, mock_kg_class):
"""Test that TrustGraph is not recreated when table hasn't changed"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
# First call should create TrustGraph
await processor.store_triples(mock_message)
assert mock_trustgraph.call_count == 1
assert mock_kg_class.call_count == 1
# Second call with same table should reuse TrustGraph
await processor.store_triples(mock_message)
assert mock_trustgraph.call_count == 1 # Should not increase
assert mock_kg_class.call_count == 1 # Should not increase
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_triple_insertion(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_triple_insertion(self, mock_kg_class):
"""Test that triples are properly inserted into Cassandra"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock triples
# Create mock triples with proper Term structure
triple1 = MagicMock()
triple1.s.type = LITERAL
triple1.s.value = 'subject1'
triple1.s.datatype = ''
triple1.s.language = ''
triple1.p.type = LITERAL
triple1.p.value = 'predicate1'
triple1.o.type = LITERAL
triple1.o.value = 'object1'
triple1.o.datatype = ''
triple1.o.language = ''
triple1.g = None
triple2 = MagicMock()
triple2.s.type = LITERAL
triple2.s.value = 'subject2'
triple2.s.datatype = ''
triple2.s.language = ''
triple2.p.type = LITERAL
triple2.p.value = 'predicate2'
triple2.o.type = LITERAL
triple2.o.value = 'object2'
triple2.o.datatype = ''
triple2.o.language = ''
triple2.g = None
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple1, triple2]
await processor.store_triples(mock_message)
# Verify both triples were inserted
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
assert mock_tg_instance.insert.call_count == 2
mock_tg_instance.insert.assert_any_call('collection1', 'subject1', 'predicate1', 'object1')
mock_tg_instance.insert.assert_any_call('collection1', 'subject2', 'predicate2', 'object2')
mock_tg_instance.insert.assert_any_call(
'collection1', 'subject1', 'predicate1', 'object1',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
mock_tg_instance.insert.assert_any_call(
'collection1', 'subject2', 'predicate2', 'object2',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_triple_insertion_with_empty_list(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_triple_insertion_with_empty_list(self, mock_kg_class):
"""Test behavior when message has no triples"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message with empty triples
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify no triples were inserted
mock_tg_instance.insert.assert_not_called()
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph):
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
"""Test exception handling during TrustGraph creation"""
taskgroup_mock = MagicMock()
mock_trustgraph.side_effect = Exception("Connection failed")
mock_kg_class.side_effect = Exception("Connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
# Verify sleep was called before re-raising
mock_sleep.assert_called_once_with(1)
@ -326,92 +349,104 @@ class TestCassandraStorageProcessor:
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_store_triples_table_switching_between_different_tables(self, mock_kg_class):
"""Test table switching when different tables are used in sequence"""
taskgroup_mock = MagicMock()
mock_tg_instance1 = MagicMock()
mock_tg_instance2 = MagicMock()
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2]
processor = Processor(taskgroup=taskgroup_mock)
# First message with table1
mock_message1 = MagicMock()
mock_message1.metadata.user = 'user1'
mock_message1.metadata.collection = 'collection1'
mock_message1.triples = []
await processor.store_triples(mock_message1)
assert processor.table == 'user1'
assert processor.tg == mock_tg_instance1
# Second message with different table
mock_message2 = MagicMock()
mock_message2.metadata.user = 'user2'
mock_message2.metadata.collection = 'collection2'
mock_message2.triples = []
await processor.store_triples(mock_message2)
assert processor.table == 'user2'
assert processor.tg == mock_tg_instance2
# Verify TrustGraph was created twice for different tables
assert mock_trustgraph.call_count == 2
assert mock_kg_class.call_count == 2
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_store_triples_with_special_characters_in_values(self, mock_kg_class):
"""Test storing triples with special characters and unicode"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create triple with special characters
# Create triple with special characters and proper Term structure
triple = MagicMock()
triple.s.type = LITERAL
triple.s.value = 'subject with spaces & symbols'
triple.s.datatype = ''
triple.s.language = ''
triple.p.type = LITERAL
triple.p.value = 'predicate:with/colons'
triple.o.type = LITERAL
triple.o.value = 'object with "quotes" and unicode: ñáéíóú'
triple.o.datatype = ''
triple.o.language = ''
triple.g = None
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with(
'test_collection',
'subject with spaces & symbols',
'predicate:with/colons',
'object with "quotes" and unicode: ñáéíóú'
'object with "quotes" and unicode: ñáéíóú',
g=DEFAULT_GRAPH,
otype='l',
dtype='',
lang=''
)
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
"""Test that table remains unchanged when TrustGraph creation fails"""
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
# Set an initial table
processor.table = ('old_user', 'old_collection')
# Mock TrustGraph to raise exception
mock_trustgraph.side_effect = Exception("Connection failed")
mock_kg_class.side_effect = Exception("Connection failed")
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
# Table should remain unchanged since self.table = table happens after try/except
assert processor.table == ('old_user', 'old_collection')
# TrustGraph should be set to None though
@ -422,12 +457,12 @@ class TestCassandraPerformanceOptimizations:
"""Test cases for multi-table performance optimizations"""
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_legacy_mode_uses_single_table(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_legacy_mode_uses_single_table(self, mock_kg_class):
"""Test that legacy mode still works with single table"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
processor = Processor(taskgroup=taskgroup_mock)
@ -440,16 +475,15 @@ class TestCassandraPerformanceOptimizations:
await processor.store_triples(mock_message)
# Verify KnowledgeGraph instance uses legacy mode
kg_instance = mock_trustgraph.return_value
assert kg_instance is not None
assert mock_tg_instance is not None
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_optimized_mode_uses_multi_table(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_optimized_mode_uses_multi_table(self, mock_kg_class):
"""Test that optimized mode uses multi-table schema"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
processor = Processor(taskgroup=taskgroup_mock)
@ -462,24 +496,31 @@ class TestCassandraPerformanceOptimizations:
await processor.store_triples(mock_message)
# Verify KnowledgeGraph instance is in optimized mode
kg_instance = mock_trustgraph.return_value
assert kg_instance is not None
assert mock_tg_instance is not None
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_batch_write_consistency(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_batch_write_consistency(self, mock_kg_class):
"""Test that all tables stay consistent during batch writes"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create test triple
# Create test triple with proper Term structure
triple = MagicMock()
triple.s.type = LITERAL
triple.s.value = 'test_subject'
triple.s.datatype = ''
triple.s.language = ''
triple.p.type = LITERAL
triple.p.value = 'test_predicate'
triple.o.type = LITERAL
triple.o.value = 'test_object'
triple.o.datatype = ''
triple.o.language = ''
triple.g = None
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
@ -490,7 +531,8 @@ class TestCassandraPerformanceOptimizations:
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
mock_tg_instance.insert.assert_called_once_with(
'collection1', 'test_subject', 'test_predicate', 'test_object'
'collection1', 'test_subject', 'test_predicate', 'test_object',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
def test_environment_variable_controls_mode(self):

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.falkordb.write import Processor
from trustgraph.schema import Value, Triple
from trustgraph.schema import Term, Triple, IRI, LITERAL
class TestFalkorDBStorageProcessor:
@ -22,9 +22,9 @@ class TestFalkorDBStorageProcessor:
# Create a test triple
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=LITERAL, value='literal object')
)
message.triples = [triple]
@ -183,9 +183,9 @@ class TestFalkorDBStorageProcessor:
message.metadata.collection = 'test_collection'
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='http://example.com/object', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=IRI, iri='http://example.com/object')
)
message.triples = [triple]
@ -269,14 +269,14 @@ class TestFalkorDBStorageProcessor:
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object1', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject1'),
p=Term(type=IRI, iri='http://example.com/predicate1'),
o=Term(type=LITERAL, value='literal object1')
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject2'),
p=Term(type=IRI, iri='http://example.com/predicate2'),
o=Term(type=IRI, iri='http://example.com/object2')
)
message.triples = [triple1, triple2]
@ -337,14 +337,14 @@ class TestFalkorDBStorageProcessor:
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject1'),
p=Term(type=IRI, iri='http://example.com/predicate1'),
o=Term(type=LITERAL, value='literal object')
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject2'),
p=Term(type=IRI, iri='http://example.com/predicate2'),
o=Term(type=IRI, iri='http://example.com/object2')
)
message.triples = [triple1, triple2]

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.memgraph.write import Processor
from trustgraph.schema import Value, Triple
from trustgraph.schema import Term, Triple, IRI, LITERAL
class TestMemgraphStorageProcessor:
@ -22,9 +22,9 @@ class TestMemgraphStorageProcessor:
# Create a test triple
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=LITERAL, value='literal object')
)
message.triples = [triple]
@ -231,9 +231,9 @@ class TestMemgraphStorageProcessor:
mock_tx = MagicMock()
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='http://example.com/object', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=IRI, iri='http://example.com/object')
)
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
@ -265,9 +265,9 @@ class TestMemgraphStorageProcessor:
mock_tx = MagicMock()
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=LITERAL, value='literal object')
)
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
@ -347,14 +347,14 @@ class TestMemgraphStorageProcessor:
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object1', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject1'),
p=Term(type=IRI, iri='http://example.com/predicate1'),
o=Term(type=LITERAL, value='literal object1')
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject2'),
p=Term(type=IRI, iri='http://example.com/predicate2'),
o=Term(type=IRI, iri='http://example.com/object2')
)
message.triples = [triple1, triple2]

View file

@ -6,6 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.storage.triples.neo4j.write import Processor
from trustgraph.schema import IRI, LITERAL
class TestNeo4jStorageProcessor:
@ -257,10 +258,12 @@ class TestNeo4jStorageProcessor:
# Create mock triple with URI object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "http://example.com/object"
triple.o.is_uri = True
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = IRI
triple.o.iri = "http://example.com/object"
# Create mock message with metadata
mock_message = MagicMock()
@ -327,10 +330,12 @@ class TestNeo4jStorageProcessor:
# Create mock triple with literal object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "literal value"
triple.o.is_uri = False
# Create mock message with metadata
mock_message = MagicMock()
@ -398,16 +403,20 @@ class TestNeo4jStorageProcessor:
# Create mock triples
triple1 = MagicMock()
triple1.s.value = "http://example.com/subject1"
triple1.p.value = "http://example.com/predicate1"
triple1.o.value = "http://example.com/object1"
triple1.o.is_uri = True
triple1.s.type = IRI
triple1.s.iri = "http://example.com/subject1"
triple1.p.type = IRI
triple1.p.iri = "http://example.com/predicate1"
triple1.o.type = IRI
triple1.o.iri = "http://example.com/object1"
triple2 = MagicMock()
triple2.s.value = "http://example.com/subject2"
triple2.p.value = "http://example.com/predicate2"
triple2.s.type = IRI
triple2.s.iri = "http://example.com/subject2"
triple2.p.type = IRI
triple2.p.iri = "http://example.com/predicate2"
triple2.o.type = LITERAL
triple2.o.value = "literal value"
triple2.o.is_uri = False
# Create mock message with metadata
mock_message = MagicMock()
@ -550,10 +559,12 @@ class TestNeo4jStorageProcessor:
# Create triple with special characters
triple = MagicMock()
triple.s.value = "http://example.com/subject with spaces"
triple.p.value = "http://example.com/predicate:with/symbols"
triple.s.type = IRI
triple.s.iri = "http://example.com/subject with spaces"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate:with/symbols"
triple.o.type = LITERAL
triple.o.value = 'literal with "quotes" and unicode: ñáéíóú'
triple.o.is_uri = False
mock_message = MagicMock()
mock_message.triples = [triple]

View file

@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert hasattr(processor, 'client')
assert hasattr(processor, 'safety_settings')
assert len(processor.safety_settings) == 4 # 4 safety categories
mock_genai_class.assert_called_once_with(api_key='test-api-key')
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'gemini-1.5-pro'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_genai_class.assert_called_once_with(api_key='test-api-key')
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# Assert
# Verify Google AI Studio client was called with correct API key
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False)
# Verify processor has the client
assert processor.client == mock_genai_client

View file

@ -1,6 +1,6 @@
"""
Unit tests for trustgraph.model.text_completion.vertexai
Starting simple with one test to get the basics working
Updated for google-genai SDK
"""
import pytest
@ -15,19 +15,20 @@ from trustgraph.base import LlmResult
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
"""Simple test for processor initialization"""
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test basic processor initialization with mocked dependencies"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
# Mock the parent class initialization to avoid taskgroup requirement
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -47,32 +48,38 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
assert processor.default_model == 'gemini-2.0-flash-001'
assert hasattr(processor, 'generation_configs') # Cache dictionary
assert hasattr(processor, 'safety_settings')
assert hasattr(processor, 'model_clients') # LLM clients are now cached
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
mock_vertexai.init.assert_called_once()
assert hasattr(processor, 'client') # genai.Client
mock_service_account.Credentials.from_service_account_file.assert_called_once()
mock_genai.Client.assert_called_once_with(
vertexai=True,
project="test-project-123",
location="us-central1",
credentials=mock_credentials
)
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test successful content generation"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Generated response from Gemini"
mock_response.usage_metadata.prompt_token_count = 15
mock_response.usage_metadata.candidates_token_count = 8
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -98,32 +105,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'gemini-2.0-flash-001'
# Check that the method was called (actual prompt format may vary)
mock_model.generate_content.assert_called_once()
# Verify the call was made with the expected parameters
call_args = mock_model.generate_content.call_args
# Generation config is now created dynamically per model
assert 'generation_config' in call_args[1]
assert call_args[1]['safety_settings'] == processor.safety_settings
mock_client.models.generate_content.assert_called_once()
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test rate limit error handling"""
# Arrange
from google.api_core.exceptions import ResourceExhausted
from trustgraph.exceptions import TooManyRequests
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -144,25 +145,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
with pytest.raises(TooManyRequests):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test handling of blocked content (safety filters)"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = None # Blocked content returns None
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 0
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -190,24 +192,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.model == 'gemini-2.0-flash-001'
@patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_auth_default):
"""Test processor initialization without private key (uses default credentials)"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Mock google.auth.default() to return credentials and project ID
mock_credentials = MagicMock()
mock_auth_default.return_value = (mock_credentials, "test-project-123")
# Mock GenerativeModel
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
config = {
'region': 'us-central1',
@ -222,30 +222,32 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'gemini-2.0-flash-001'
mock_auth_default.assert_called_once()
mock_vertexai.init.assert_called_once_with(
location='us-central1',
project='test-project-123'
mock_genai.Client.assert_called_once_with(
vertexai=True,
project="test-project-123",
location="us-central1",
credentials=mock_credentials
)
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test handling of generic exceptions"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_model.generate_content.side_effect = Exception("Network error")
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.side_effect = Exception("Network error")
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -266,19 +268,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
with pytest.raises(Exception, match="Network error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test processor initialization with custom parameters"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -298,37 +301,37 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Assert
assert processor.default_model == 'gemini-1.5-pro'
# Verify that generation_config object exists (can't easily check internal values)
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
# Verify that generation_config cache exists
assert hasattr(processor, 'generation_configs')
assert processor.generation_configs == {} # Empty cache initially
# Verify that safety settings are configured
assert len(processor.safety_settings) == 4
# Verify service account was called with custom key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
# Verify that api_params dict has the correct values (this is accessible)
mock_service_account.Credentials.from_service_account_file.assert_called_once()
# Verify that api_params dict has the correct values
assert processor.api_params["temperature"] == 0.7
assert processor.api_params["max_output_tokens"] == 4096
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test that VertexAI is initialized correctly with credentials"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -347,35 +350,34 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
# Verify VertexAI init was called with correct parameters
mock_vertexai.init.assert_called_once_with(
# Verify genai.Client was called with correct parameters
mock_genai.Client.assert_called_once_with(
vertexai=True,
project='test-project-123',
location='europe-west1',
credentials=mock_credentials,
project='test-project-123'
credentials=mock_credentials
)
# GenerativeModel is now created lazily on first use, not at initialization
mock_generative_model.assert_not_called()
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test content generation with empty prompts"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Default response"
mock_response.usage_metadata.prompt_token_count = 2
mock_response.usage_metadata.candidates_token_count = 3
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -401,27 +403,28 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'gemini-2.0-flash-001'
# Verify the model was called with the combined empty prompts
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# The prompt should be "" + "\n\n" + "" = "\n\n"
assert call_args[0][0] == "\n\n"
# Verify the client was called
mock_client.models.generate_content.assert_called_once()
@patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_anthropic_vertex):
"""Test Anthropic processor initialization with private key credentials"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-456"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
# Mock AnthropicVertex
mock_anthropic_client = MagicMock()
mock_anthropic_vertex.return_value = mock_anthropic_client
@ -439,45 +442,45 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'claude-3-sonnet@20240229'
# is_anthropic logic is now determined dynamically per request
# Verify service account was called with private key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
# Verify AnthropicVertex was initialized with credentials
mock_service_account.Credentials.from_service_account_file.assert_called_once()
# Verify AnthropicVertex was initialized with credentials (because model contains 'claude')
mock_anthropic_vertex.assert_called_once_with(
region='us-west1',
project_id='test-project-456',
credentials=mock_credentials
)
# Verify api_params are set correctly
assert processor.api_params["temperature"] == 0.5
assert processor.api_params["max_output_tokens"] == 2048
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test temperature parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom temperature"
mock_response.usage_metadata.prompt_token_count = 20
mock_response.usage_metadata.candidates_token_count = 12
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -506,42 +509,27 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
mock_client.models.generate_content.assert_called_once()
# Verify Gemini API was called with overridden temperature
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Check that generation_config was created (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test model parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
# Mock different models
mock_model_default = MagicMock()
mock_model_override = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom model"
mock_response.usage_metadata.prompt_token_count = 18
mock_response.usage_metadata.candidates_token_count = 14
mock_model_override.generate_content.return_value = mock_response
# GenerativeModel should return different models based on input
def model_factory(model_name):
if model_name == 'gemini-1.5-pro':
return mock_model_override
return mock_model_default
mock_generative_model.side_effect = model_factory
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -549,7 +537,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001', # Default model
'temperature': 0.2, # Default temperature
'temperature': 0.2,
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
@ -571,29 +559,29 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify the overridden model was used
mock_model_override.generate_content.assert_called_once()
# Verify GenerativeModel was called with the override model
mock_generative_model.assert_called_with('gemini-1.5-pro')
# Verify the call was made with the override model
call_args = mock_client.models.generate_content.call_args
assert call_args.kwargs['model'] == "gemini-1.5-pro"
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with both overrides"
mock_response.usage_metadata.prompt_token_count = 22
mock_response.usage_metadata.candidates_token_count = 16
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -622,18 +610,12 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
mock_client.models.generate_content.assert_called_once()
# Verify both overrides were used
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Verify model override
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
# Verify temperature override (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
# Verify the model override was used
call_args = mock_client.models.generate_content.call_args
assert call_args.kwargs['model'] == "gemini-1.5-flash-001"
if __name__ == '__main__':
pytest.main([__file__])
pytest.main([__file__])