mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 09:56:22 +02:00
Merge 2.0 to master (#651)
This commit is contained in:
parent
3666ece2c5
commit
b9d7bf9a8b
212 changed files with 13940 additions and 6180 deletions
|
|
@ -12,7 +12,7 @@ import json
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, AgentRequest, AgentResponse
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from trustgraph.template.prompt_manager import PromptManager
|
||||
|
|
@ -30,38 +30,16 @@ class TestAgentKgExtractionIntegration:
|
|||
# Mock agent client
|
||||
agent_client = AsyncMock()
|
||||
|
||||
# Mock successful agent response
|
||||
# Mock successful agent response in JSONL format
|
||||
def mock_agent_response(recipient, question):
|
||||
# Simulate agent processing and return structured response
|
||||
# Simulate agent processing and return structured JSONL response
|
||||
mock_response = MagicMock()
|
||||
mock_response.error = None
|
||||
mock_response.answer = '''```json
|
||||
{
|
||||
"definitions": [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": true
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": true
|
||||
}
|
||||
]
|
||||
}
|
||||
{"type": "definition", "entity": "Machine Learning", "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."}
|
||||
{"type": "definition", "entity": "Neural Networks", "definition": "Computing systems inspired by biological neural networks that process information."}
|
||||
{"type": "relationship", "subject": "Machine Learning", "predicate": "is_subset_of", "object": "Artificial Intelligence", "object-entity": true}
|
||||
{"type": "relationship", "subject": "Neural Networks", "predicate": "used_in", "object": "Machine Learning", "object-entity": true}
|
||||
```'''
|
||||
return mock_response.answer
|
||||
|
||||
|
|
@ -100,9 +78,9 @@ class TestAgentKgExtractionIntegration:
|
|||
id="doc123",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Value(value="doc123", is_uri=True),
|
||||
p=Value(value="http://example.org/type", is_uri=True),
|
||||
o=Value(value="document", is_uri=False)
|
||||
s=Term(type=IRI, iri="doc123"),
|
||||
p=Term(type=IRI, iri="http://example.org/type"),
|
||||
o=Term(type=LITERAL, value="document")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -120,7 +98,7 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
# Copy the methods we want to test
|
||||
extractor.to_uri = real_extractor.to_uri
|
||||
extractor.parse_json = real_extractor.parse_json
|
||||
extractor.parse_jsonl = real_extractor.parse_jsonl
|
||||
extractor.process_extraction_data = real_extractor.process_extraction_data
|
||||
extractor.emit_triples = real_extractor.emit_triples
|
||||
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
|
||||
|
|
@ -156,7 +134,7 @@ class TestAgentKgExtractionIntegration:
|
|||
agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt)
|
||||
|
||||
# Parse and process
|
||||
extraction_data = extractor.parse_json(agent_response)
|
||||
extraction_data = extractor.parse_jsonl(agent_response)
|
||||
triples, entity_contexts = extractor.process_extraction_data(extraction_data, v.metadata)
|
||||
|
||||
# Add metadata triples
|
||||
|
|
@ -200,15 +178,15 @@ class TestAgentKgExtractionIntegration:
|
|||
assert len(sent_triples.triples) > 0
|
||||
|
||||
# Check that we have definition triples
|
||||
definition_triples = [t for t in sent_triples.triples if t.p.value == DEFINITION]
|
||||
definition_triples = [t for t in sent_triples.triples if t.p.iri == DEFINITION]
|
||||
assert len(definition_triples) >= 2 # Should have definitions for ML and Neural Networks
|
||||
|
||||
|
||||
# Check that we have label triples
|
||||
label_triples = [t for t in sent_triples.triples if t.p.value == RDF_LABEL]
|
||||
label_triples = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL]
|
||||
assert len(label_triples) >= 2 # Should have labels for entities
|
||||
|
||||
|
||||
# Check subject-of relationships
|
||||
subject_of_triples = [t for t in sent_triples.triples if t.p.value == SUBJECT_OF]
|
||||
subject_of_triples = [t for t in sent_triples.triples if t.p.iri == SUBJECT_OF]
|
||||
assert len(subject_of_triples) >= 2 # Entities should be linked to document
|
||||
|
||||
# Verify entity contexts were emitted
|
||||
|
|
@ -220,7 +198,7 @@ class TestAgentKgExtractionIntegration:
|
|||
assert len(sent_contexts.entities) >= 2 # Should have contexts for both entities
|
||||
|
||||
# Verify entity URIs are properly formed
|
||||
entity_uris = [ec.entity.value for ec in sent_contexts.entities]
|
||||
entity_uris = [ec.entity.iri for ec in sent_contexts.entities]
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris
|
||||
|
||||
|
|
@ -248,22 +226,28 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_response_handling(self, configured_agent_extractor, sample_chunk, mock_flow_context):
|
||||
"""Test handling of invalid JSON responses from agent"""
|
||||
"""Test handling of invalid JSON responses from agent - JSONL is lenient and skips invalid lines"""
|
||||
# Arrange - mock invalid JSON response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
|
||||
def mock_invalid_json_response(recipient, question):
|
||||
return "This is not valid JSON at all"
|
||||
|
||||
|
||||
agent_client.invoke = mock_invalid_json_response
|
||||
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises((ValueError, json.JSONDecodeError)):
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
# Act - JSONL parsing is lenient, invalid lines are skipped
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert - should emit triples (with just metadata) but no entity contexts
|
||||
triples_publisher = mock_flow_context("triples")
|
||||
triples_publisher.send.assert_called_once()
|
||||
|
||||
entity_contexts_publisher = mock_flow_context("entity-contexts")
|
||||
entity_contexts_publisher.send.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_extraction_results(self, configured_agent_extractor, sample_chunk, mock_flow_context):
|
||||
|
|
@ -272,7 +256,8 @@ class TestAgentKgExtractionIntegration:
|
|||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_empty_response(recipient, question):
|
||||
return '{"definitions": [], "relationships": []}'
|
||||
# Return empty JSONL (just empty/whitespace)
|
||||
return ''
|
||||
|
||||
agent_client.invoke = mock_empty_response
|
||||
|
||||
|
|
@ -303,7 +288,8 @@ class TestAgentKgExtractionIntegration:
|
|||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_malformed_response(recipient, question):
|
||||
return '''{"definitions": [{"entity": "Missing Definition"}], "relationships": [{"subject": "Missing Object"}]}'''
|
||||
# JSONL with definition missing required field
|
||||
return '{"type": "definition", "entity": "Missing Definition"}'
|
||||
|
||||
agent_client.invoke = mock_malformed_response
|
||||
|
||||
|
|
@ -330,7 +316,7 @@ class TestAgentKgExtractionIntegration:
|
|||
def capture_prompt(recipient, question):
|
||||
# Verify the prompt contains the test text
|
||||
assert test_text in question
|
||||
return '{"definitions": [], "relationships": []}'
|
||||
return '' # Empty JSONL response
|
||||
|
||||
agent_client.invoke = capture_prompt
|
||||
|
||||
|
|
@ -361,7 +347,7 @@ class TestAgentKgExtractionIntegration:
|
|||
responses = []
|
||||
|
||||
def mock_response(recipient, question):
|
||||
response = f'{{"definitions": [{{"entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}], "relationships": []}}'
|
||||
response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}'
|
||||
responses.append(response)
|
||||
return response
|
||||
|
||||
|
|
@ -398,7 +384,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Verify unicode text was properly decoded and included
|
||||
assert "学习机器" in question
|
||||
assert "人工知能" in question
|
||||
return '''{"definitions": [{"entity": "機械学習", "definition": "人工知能の一分野"}], "relationships": []}'''
|
||||
return '{"type": "definition", "entity": "機械学習", "definition": "人工知能の一分野"}'
|
||||
|
||||
agent_client.invoke = mock_unicode_response
|
||||
|
||||
|
|
@ -415,7 +401,7 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
sent_triples = triples_publisher.send.call_args[0][0]
|
||||
# Check that unicode entity was properly processed
|
||||
entity_labels = [t for t in sent_triples.triples if t.p.value == RDF_LABEL and t.o.value == "機械学習"]
|
||||
entity_labels = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL and t.o.value == "機械学習"]
|
||||
assert len(entity_labels) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -433,7 +419,7 @@ class TestAgentKgExtractionIntegration:
|
|||
def mock_large_text_response(recipient, question):
|
||||
# Verify large text was included
|
||||
assert len(question) > 10000
|
||||
return '''{"definitions": [{"entity": "Machine Learning", "definition": "Important AI technique"}], "relationships": []}'''
|
||||
return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}'
|
||||
|
||||
agent_client.invoke = mock_large_text_response
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from argparse import ArgumentParser
|
|||
|
||||
# Import processors that use Cassandra configuration
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
|
||||
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
|
||||
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
|
||||
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
|
||||
from trustgraph.storage.knowledge.store import Processor as KgStore
|
||||
|
||||
|
|
@ -55,8 +55,8 @@ class TestEndToEndConfigurationFlow:
|
|||
assert call_args.args[0] == ['integration-host1', 'integration-host2', 'integration-host3']
|
||||
assert 'auth_provider' in call_args.kwargs # Should have auth since credentials provided
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_objects_writer_env_to_cluster_connection(self, mock_auth_provider, mock_cluster):
|
||||
"""Test complete flow from environment variables to Cassandra Cluster connection."""
|
||||
env_vars = {
|
||||
|
|
@ -73,7 +73,7 @@ class TestEndToEndConfigurationFlow:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
|
||||
# Trigger Cassandra connection
|
||||
processor.connect_cassandra()
|
||||
|
|
@ -320,7 +320,7 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
class TestMultipleHostsHandling:
|
||||
"""Test multiple Cassandra hosts handling end-to-end."""
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
def test_multiple_hosts_passed_to_cluster(self, mock_cluster):
|
||||
"""Test that multiple hosts are correctly passed to Cassandra cluster."""
|
||||
env_vars = {
|
||||
|
|
@ -333,7 +333,7 @@ class TestMultipleHostsHandling:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify all hosts were passed to Cluster
|
||||
|
|
@ -386,8 +386,8 @@ class TestMultipleHostsHandling:
|
|||
class TestAuthenticationFlow:
|
||||
"""Test authentication configuration flow end-to-end."""
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_authentication_enabled_when_both_credentials_provided(self, mock_auth_provider, mock_cluster):
|
||||
"""Test that authentication is enabled when both username and password are provided."""
|
||||
env_vars = {
|
||||
|
|
@ -402,7 +402,7 @@ class TestAuthenticationFlow:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Auth provider should be created
|
||||
|
|
@ -416,8 +416,8 @@ class TestAuthenticationFlow:
|
|||
assert 'auth_provider' in call_args.kwargs
|
||||
assert call_args.kwargs['auth_provider'] == mock_auth_instance
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster):
|
||||
"""Test that authentication is not used when credentials are missing."""
|
||||
env_vars = {
|
||||
|
|
@ -429,7 +429,7 @@ class TestAuthenticationFlow:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Auth provider should not be created
|
||||
|
|
@ -439,11 +439,11 @@ class TestAuthenticationFlow:
|
|||
call_args = mock_cluster.call_args
|
||||
assert 'auth_provider' not in call_args.kwargs
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster):
|
||||
"""Test that authentication is not used when only username is provided."""
|
||||
processor = ObjectsWriter(
|
||||
processor = RowsWriter(
|
||||
taskgroup=MagicMock(),
|
||||
cassandra_host='partial-auth-host',
|
||||
cassandra_username='partial-user'
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .cassandra_test_helper import cassandra_container
|
|||
from trustgraph.direct.cassandra_kg import KnowledgeGraph
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor
|
||||
from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor
|
||||
from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest
|
||||
from trustgraph.schema import Triple, Term, Metadata, Triples, TriplesQueryRequest, IRI, LITERAL
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -118,19 +118,19 @@ class TestCassandraIntegration:
|
|||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.org/person1", is_uri=True),
|
||||
p=Value(value="http://example.org/name", is_uri=True),
|
||||
o=Value(value="Alice Smith", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/person1"),
|
||||
p=Term(type=IRI, iri="http://example.org/name"),
|
||||
o=Term(type=LITERAL, value="Alice Smith")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/person1", is_uri=True),
|
||||
p=Value(value="http://example.org/age", is_uri=True),
|
||||
o=Value(value="25", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/person1"),
|
||||
p=Term(type=IRI, iri="http://example.org/age"),
|
||||
o=Term(type=LITERAL, value="25")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/person1", is_uri=True),
|
||||
p=Value(value="http://example.org/department", is_uri=True),
|
||||
o=Value(value="Engineering", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/person1"),
|
||||
p=Term(type=IRI, iri="http://example.org/department"),
|
||||
o=Term(type=LITERAL, value="Engineering")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -181,19 +181,19 @@ class TestCassandraIntegration:
|
|||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.org/alice", is_uri=True),
|
||||
p=Value(value="http://example.org/knows", is_uri=True),
|
||||
o=Value(value="http://example.org/bob", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://example.org/alice"),
|
||||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=Term(type=IRI, iri="http://example.org/bob")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/alice", is_uri=True),
|
||||
p=Value(value="http://example.org/age", is_uri=True),
|
||||
o=Value(value="30", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/alice"),
|
||||
p=Term(type=IRI, iri="http://example.org/age"),
|
||||
o=Term(type=LITERAL, value="30")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/bob", is_uri=True),
|
||||
p=Value(value="http://example.org/knows", is_uri=True),
|
||||
o=Value(value="http://example.org/charlie", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://example.org/bob"),
|
||||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=Term(type=IRI, iri="http://example.org/charlie")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -208,7 +208,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Test S query (find all relationships for Alice)
|
||||
s_query = TriplesQueryRequest(
|
||||
s=Value(value="http://example.org/alice", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.org/alice"),
|
||||
p=None, # None for wildcard
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
|
|
@ -218,18 +218,18 @@ class TestCassandraIntegration:
|
|||
s_results = await query_processor.query_triples(s_query)
|
||||
print(f"Query processor results: {len(s_results)}")
|
||||
for result in s_results:
|
||||
print(f" S={result.s.value}, P={result.p.value}, O={result.o.value}")
|
||||
print(f" S={result.s.iri}, P={result.p.iri}, O={result.o.iri if result.o.type == IRI else result.o.value}")
|
||||
assert len(s_results) == 2
|
||||
|
||||
s_predicates = [t.p.value for t in s_results]
|
||||
|
||||
s_predicates = [t.p.iri for t in s_results]
|
||||
assert "http://example.org/knows" in s_predicates
|
||||
assert "http://example.org/age" in s_predicates
|
||||
print("✓ Subject queries via processor working")
|
||||
|
||||
|
||||
# Test P query (find all "knows" relationships)
|
||||
p_query = TriplesQueryRequest(
|
||||
s=None, # None for wildcard
|
||||
p=Value(value="http://example.org/knows", is_uri=True),
|
||||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
|
|
@ -238,8 +238,8 @@ class TestCassandraIntegration:
|
|||
p_results = await query_processor.query_triples(p_query)
|
||||
print(p_results)
|
||||
assert len(p_results) == 2 # Alice knows Bob, Bob knows Charlie
|
||||
|
||||
p_subjects = [t.s.value for t in p_results]
|
||||
|
||||
p_subjects = [t.s.iri for t in p_results]
|
||||
assert "http://example.org/alice" in p_subjects
|
||||
assert "http://example.org/bob" in p_subjects
|
||||
print("✓ Predicate queries via processor working")
|
||||
|
|
@ -262,19 +262,19 @@ class TestCassandraIntegration:
|
|||
metadata=Metadata(user="concurrent_test", collection="people"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
|
||||
p=Value(value="http://example.org/name", is_uri=True),
|
||||
o=Value(value=name, is_uri=False)
|
||||
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
|
||||
p=Term(type=IRI, iri="http://example.org/name"),
|
||||
o=Term(type=LITERAL, value=name)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
|
||||
p=Value(value="http://example.org/age", is_uri=True),
|
||||
o=Value(value=str(age), is_uri=False)
|
||||
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
|
||||
p=Term(type=IRI, iri="http://example.org/age"),
|
||||
o=Term(type=LITERAL, value=str(age))
|
||||
),
|
||||
Triple(
|
||||
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
|
||||
p=Value(value="http://example.org/department", is_uri=True),
|
||||
o=Value(value=department, is_uri=False)
|
||||
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
|
||||
p=Term(type=IRI, iri="http://example.org/department"),
|
||||
o=Term(type=LITERAL, value=department)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -333,36 +333,36 @@ class TestCassandraIntegration:
|
|||
triples=[
|
||||
# People and their types
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://company.org/Employee", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://company.org/Employee")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://company.org/bob", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://company.org/Employee", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://company.org/bob"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://company.org/Employee")
|
||||
),
|
||||
# Relationships
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/reportsTo", is_uri=True),
|
||||
o=Value(value="http://company.org/bob", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://company.org/reportsTo"),
|
||||
o=Term(type=IRI, iri="http://company.org/bob")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/worksIn", is_uri=True),
|
||||
o=Value(value="http://company.org/engineering", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://company.org/worksIn"),
|
||||
o=Term(type=IRI, iri="http://company.org/engineering")
|
||||
),
|
||||
# Personal info
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/fullName", is_uri=True),
|
||||
o=Value(value="Alice Johnson", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://company.org/fullName"),
|
||||
o=Term(type=LITERAL, value="Alice Johnson")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/email", is_uri=True),
|
||||
o=Value(value="alice@company.org", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://company.org/email"),
|
||||
o=Term(type=LITERAL, value="alice@company.org")
|
||||
),
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -51,10 +51,10 @@ class MockWebSocket:
|
|||
"metadata": {
|
||||
"id": "test-id",
|
||||
"metadata": {},
|
||||
"user": "test-user",
|
||||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": "subject", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": "subject"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -118,7 +118,7 @@ async def test_import_graceful_shutdown_integration(mock_backend):
|
|||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": f"subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"object-{i}", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": f"subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": f"object-{i}"}}]
|
||||
}
|
||||
messages.append(msg_data)
|
||||
|
||||
|
|
@ -163,7 +163,7 @@ async def test_export_no_message_loss_integration(mock_backend):
|
|||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": f"export-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"export-object-{i}", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": f"export-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": f"export-object-{i}"}}]
|
||||
}
|
||||
# Create Triples object instead of raw dict
|
||||
from trustgraph.schema import Triples, Metadata
|
||||
|
|
@ -302,7 +302,7 @@ async def test_concurrent_import_export_shutdown():
|
|||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": f"concurrent-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": f"concurrent-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
|
||||
}
|
||||
await import_handler.receive(msg)
|
||||
|
||||
|
|
@ -359,7 +359,7 @@ async def test_websocket_close_during_message_processing():
|
|||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": f"slow-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": f"slow-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
|
||||
}
|
||||
task = asyncio.create_task(import_handler.receive(msg))
|
||||
message_tasks.append(task)
|
||||
|
|
@ -423,7 +423,7 @@ async def test_backpressure_during_shutdown():
|
|||
# Simulate receiving and processing a message
|
||||
msg_data = {
|
||||
"metadata": {"id": f"msg-{i}"},
|
||||
"triples": [{"s": {"v": "subject", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": "subject"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
|
||||
}
|
||||
await ws.send_json(msg_data)
|
||||
# Check if we should stop
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
from trustgraph.extract.kg.definitions.extract import Processor as DefinitionsProcessor
|
||||
from trustgraph.extract.kg.relationships.extract import Processor as RelationshipsProcessor
|
||||
from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
|
||||
|
||||
|
|
@ -147,6 +147,8 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
processor.emit_triples = DefinitionsProcessor.emit_triples.__get__(processor, DefinitionsProcessor)
|
||||
processor.emit_ecs = DefinitionsProcessor.emit_ecs.__get__(processor, DefinitionsProcessor)
|
||||
processor.on_message = DefinitionsProcessor.on_message.__get__(processor, DefinitionsProcessor)
|
||||
processor.triples_batch_size = 50
|
||||
processor.entity_batch_size = 5
|
||||
return processor
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -156,6 +158,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
processor.to_uri = RelationshipsProcessor.to_uri.__get__(processor, RelationshipsProcessor)
|
||||
processor.emit_triples = RelationshipsProcessor.emit_triples.__get__(processor, RelationshipsProcessor)
|
||||
processor.on_message = RelationshipsProcessor.on_message.__get__(processor, RelationshipsProcessor)
|
||||
processor.triples_batch_size = 50
|
||||
return processor
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -253,24 +256,24 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
|
||||
if s and o:
|
||||
s_uri = definitions_processor.to_uri(s)
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
s_term = Term(type=IRI, iri=str(s_uri))
|
||||
o_term = Term(type=LITERAL, value=str(o))
|
||||
|
||||
# Generate triples as the processor would
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=s, is_uri=False)
|
||||
s=s_term,
|
||||
p=Term(type=IRI, iri=RDF_LABEL),
|
||||
o=Term(type=LITERAL, value=s)
|
||||
))
|
||||
|
||||
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=Value(value=DEFINITION, is_uri=True),
|
||||
o=o_value
|
||||
s=s_term,
|
||||
p=Term(type=IRI, iri=DEFINITION),
|
||||
o=o_term
|
||||
))
|
||||
|
||||
|
||||
entities.append(EntityContext(
|
||||
entity=s_value,
|
||||
entity=s_term,
|
||||
context=defn["definition"]
|
||||
))
|
||||
|
||||
|
|
@ -279,16 +282,16 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
assert len(entities) == 3 # 1 entity context per entity
|
||||
|
||||
# Verify triple structure
|
||||
label_triples = [t for t in triples if t.p.value == RDF_LABEL]
|
||||
definition_triples = [t for t in triples if t.p.value == DEFINITION]
|
||||
|
||||
label_triples = [t for t in triples if t.p.iri == RDF_LABEL]
|
||||
definition_triples = [t for t in triples if t.p.iri == DEFINITION]
|
||||
|
||||
assert len(label_triples) == 3
|
||||
assert len(definition_triples) == 3
|
||||
|
||||
|
||||
# Verify entity contexts
|
||||
for entity in entities:
|
||||
assert entity.entity.is_uri is True
|
||||
assert entity.entity.value.startswith(TRUSTGRAPH_ENTITIES)
|
||||
assert entity.entity.type == IRI
|
||||
assert entity.entity.iri.startswith(TRUSTGRAPH_ENTITIES)
|
||||
assert len(entity.context) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -309,52 +312,52 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
s = rel["subject"]
|
||||
p = rel["predicate"]
|
||||
o = rel["object"]
|
||||
|
||||
|
||||
if s and p and o:
|
||||
s_uri = relationships_processor.to_uri(s)
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
|
||||
s_term = Term(type=IRI, iri=str(s_uri))
|
||||
|
||||
p_uri = relationships_processor.to_uri(p)
|
||||
p_value = Value(value=str(p_uri), is_uri=True)
|
||||
|
||||
p_term = Term(type=IRI, iri=str(p_uri))
|
||||
|
||||
if rel["object-entity"]:
|
||||
o_uri = relationships_processor.to_uri(o)
|
||||
o_value = Value(value=str(o_uri), is_uri=True)
|
||||
o_term = Term(type=IRI, iri=str(o_uri))
|
||||
else:
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
o_term = Term(type=LITERAL, value=str(o))
|
||||
|
||||
# Main relationship triple
|
||||
triples.append(Triple(s=s_value, p=p_value, o=o_value))
|
||||
|
||||
triples.append(Triple(s=s_term, p=p_term, o=o_term))
|
||||
|
||||
# Label triples
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=str(s), is_uri=False)
|
||||
s=s_term,
|
||||
p=Term(type=IRI, iri=RDF_LABEL),
|
||||
o=Term(type=LITERAL, value=str(s))
|
||||
))
|
||||
|
||||
|
||||
triples.append(Triple(
|
||||
s=p_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=str(p), is_uri=False)
|
||||
s=p_term,
|
||||
p=Term(type=IRI, iri=RDF_LABEL),
|
||||
o=Term(type=LITERAL, value=str(p))
|
||||
))
|
||||
|
||||
|
||||
if rel["object-entity"]:
|
||||
triples.append(Triple(
|
||||
s=o_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=str(o), is_uri=False)
|
||||
s=o_term,
|
||||
p=Term(type=IRI, iri=RDF_LABEL),
|
||||
o=Term(type=LITERAL, value=str(o))
|
||||
))
|
||||
|
||||
# Assert
|
||||
assert len(triples) > 0
|
||||
|
||||
# Verify relationship triples exist
|
||||
relationship_triples = [t for t in triples if t.p.value.endswith("is_subset_of") or t.p.value.endswith("is_used_in")]
|
||||
relationship_triples = [t for t in triples if t.p.iri.endswith("is_subset_of") or t.p.iri.endswith("is_used_in")]
|
||||
assert len(relationship_triples) >= 2
|
||||
|
||||
|
||||
# Verify label triples
|
||||
label_triples = [t for t in triples if t.p.value == RDF_LABEL]
|
||||
label_triples = [t for t in triples if t.p.iri == RDF_LABEL]
|
||||
assert len(label_triples) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -374,9 +377,9 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://trustgraph.ai/e/machine-learning", is_uri=True),
|
||||
p=Value(value=DEFINITION, is_uri=True),
|
||||
o=Value(value="A subset of AI", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"),
|
||||
p=Term(type=IRI, iri=DEFINITION),
|
||||
o=Term(type=LITERAL, value="A subset of AI")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -405,9 +408,14 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
entities=[]
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri="http://example.org/entity"),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_embeddings
|
||||
|
||||
|
|
@ -496,12 +504,12 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Should still call producers but with empty results
|
||||
# Should NOT call producers with empty results (avoids Cassandra NULL issues)
|
||||
triples_producer = mock_flow_context("triples")
|
||||
entity_contexts_producer = mock_flow_context("entity-contexts")
|
||||
|
||||
triples_producer.send.assert_called_once()
|
||||
entity_contexts_producer.send.assert_called_once()
|
||||
|
||||
triples_producer.send.assert_not_called()
|
||||
entity_contexts_producer.send.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
|
|
@ -602,9 +610,9 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
collection="test_collection",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Value(value="doc:test", is_uri=True),
|
||||
p=Value(value="dc:title", is_uri=True),
|
||||
o=Value(value="Test Document", is_uri=False)
|
||||
s=Term(type=IRI, iri="doc:test"),
|
||||
p=Term(type=IRI, iri="dc:title"),
|
||||
o=Term(type=LITERAL, value="Test Document")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import json
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.extract.kg.objects.processor import Processor
|
||||
from trustgraph.extract.kg.rows.processor import Processor
|
||||
from trustgraph.schema import (
|
||||
Chunk, ExtractedObject, Metadata, RowSchema, Field,
|
||||
PromptRequest, PromptResponse
|
||||
|
|
@ -220,7 +220,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
|
|
@ -288,7 +288,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
|
|
@ -353,7 +353,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
|
|
@ -447,7 +447,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Mock flow with failing prompt service
|
||||
|
|
@ -496,7 +496,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
|
|
|
|||
|
|
@ -1,608 +0,0 @@
|
|||
"""
|
||||
Integration tests for Cassandra Object Storage
|
||||
|
||||
These tests verify the end-to-end functionality of storing ExtractedObjects
|
||||
in Cassandra, including table creation, data insertion, and error handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from trustgraph.storage.objects.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestObjectsCassandraIntegration:
|
||||
"""Integration tests for Cassandra object storage"""
|
||||
|
||||
@pytest.fixture
def mock_cassandra_session(self):
    """Return a MagicMock Cassandra session whose ``execute`` tracks keyspaces.

    ``execute`` records the keyspace name of every
    ``CREATE KEYSPACE IF NOT EXISTS <name>`` statement it sees.  For queries
    that mention ``system_schema.keyspaces`` the returned result's ``one()``
    yields a MagicMock when the first positional argument is a recorded
    keyspace, and ``None`` otherwise.  For every other query ``one()``
    yields ``None``.
    """
    import re

    keyspace_ddl = re.compile(r'CREATE KEYSPACE IF NOT EXISTS (\w+)')
    seen_keyspaces = set()

    def fake_execute(query, *args, **kwargs):
        text = str(query)

        # Remember every keyspace the code under test asks to create.
        if "CREATE KEYSPACE" in text:
            found = keyspace_ddl.search(text)
            if found:
                seen_keyspaces.add(found.group(1))

        # Only existence checks against a remembered keyspace produce a row.
        is_lookup = "system_schema.keyspaces" in text
        known = bool(is_lookup and args and args[0] in seen_keyspaces)

        outcome = MagicMock()
        outcome.one.return_value = MagicMock() if known else None
        return outcome

    fake_session = MagicMock()
    fake_session.execute = MagicMock(side_effect=fake_execute)
    return fake_session
|
||||
|
||||
@pytest.fixture
def mock_cassandra_cluster(self, mock_cassandra_session):
    """Return a MagicMock cluster whose ``connect()`` hands back the mock session."""
    fake_cluster = MagicMock()
    # Explicit mock so shutdown calls can be inspected independently.
    fake_cluster.shutdown = MagicMock()
    fake_cluster.connect.return_value = mock_cassandra_session
    return fake_cluster
|
||||
|
||||
@pytest.fixture
def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
    """Build a MagicMock processor with the real ``Processor`` methods bound to it.

    The processor starts with no schemas, no known keyspaces or tables, and
    no open cluster/session, so each test exercises connection and DDL
    handling from scratch.

    Returns:
        A ``(processor, cluster, session)`` tuple; the cluster and session
        are the mocks supplied by the sibling fixtures.
    """
    processor = MagicMock()

    # Legacy attribute names, kept unchanged for anything still reading them.
    processor.graph_host = "localhost"
    processor.graph_username = None
    processor.graph_password = None

    # The tests in this class configure authentication through
    # cassandra_username / cassandra_password.  An attribute that is never
    # assigned on a MagicMock is an auto-created (truthy) mock, so pin the
    # credentials to None to make "no authentication" the explicit default.
    processor.cassandra_username = None
    processor.cassandra_password = None

    processor.config_key = "schema"
    processor.schemas = {}
    processor.known_keyspaces = set()
    processor.known_tables = {}
    processor.cluster = None
    processor.session = None

    # Bind the real implementations so the tests run production logic
    # against the mocked state above.
    for method_name in (
        "connect_cassandra",
        "ensure_keyspace",
        "ensure_table",
        "sanitize_name",
        "sanitize_table",
        "get_cassandra_type",
        "convert_value",
        "on_schema_config",
        "on_object",
        "create_collection",
    ):
        bound = getattr(Processor, method_name).__get__(processor, Processor)
        setattr(processor, method_name, bound)

    return processor, mock_cassandra_cluster, mock_cassandra_session
|
||||
|
||||
@pytest.mark.asyncio
async def test_end_to_end_object_storage(self, processor_with_mocks):
    """Drive one object from schema configuration through to the INSERT.

    Loads a ``customer_records`` schema, creates the collection, feeds a
    single ExtractedObject to ``on_object`` and then inspects the statements
    recorded on the mocked session.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    def executed(fragment):
        # Recorded session.execute() calls whose repr mentions the fragment.
        return [c for c in mock_session.execute.call_args_list
                if fragment in str(c)]

    customer_schema = {
        "name": "customer_records",
        "description": "Customer information",
        "fields": [
            {"name": "customer_id", "type": "string", "primary_key": True},
            {"name": "name", "type": "string", "required": True},
            {"name": "email", "type": "string", "indexed": True},
            {"name": "age", "type": "integer"}
        ]
    }
    customer_row = {
        "customer_id": "CUST001",
        "name": "John Doe",
        "email": "john@example.com",
        "age": "30"
    }

    # Mock Cluster creation
    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        # Step 1: configure the schema
        config = {"schema": {"customer_records": json.dumps(customer_schema)}}
        await processor.on_schema_config(config, version=1)
        assert "customer_records" in processor.schemas

        # Step 1.5: create the collection first (simulate tg-set-collection)
        await processor.create_collection("test_user", "import_2024", {})

        # Step 2: process an ExtractedObject
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(
                id="doc-001",
                user="test_user",
                collection="import_2024",
                metadata=[]
            ),
            schema_name="customer_records",
            values=[customer_row],
            confidence=0.95,
            source_span="Customer: John Doe..."
        )
        await processor.on_object(message, None, None)

        # A connection must have been opened
        assert mock_cluster.connect.called

        # Exactly one keyspace, named after the user
        keyspace_calls = executed("CREATE KEYSPACE")
        assert len(keyspace_calls) == 1
        assert "test_user" in str(keyspace_calls[0])

        # Exactly one table; table names get an o_ prefix
        table_calls = executed("CREATE TABLE")
        assert len(table_calls) == 1
        table_ddl = str(table_calls[0])
        assert "o_customer_records" in table_ddl
        assert "collection text" in table_ddl
        assert "PRIMARY KEY ((collection, customer_id))" in table_ddl

        # One index, on the indexed email field
        index_calls = executed("CREATE INDEX")
        assert len(index_calls) == 1
        assert "email" in str(index_calls[0])

        # One insert into the prefixed table
        insert_calls = executed("INSERT INTO")
        assert len(insert_calls) == 1
        insert_call = insert_calls[0]
        assert "test_user.o_customer_records" in str(insert_call)

        # Inserted values: collection, customer_id, name, email, and the age
        # converted to int
        values = insert_call[0][1]
        for expected in ("import_2024", "CUST001", "John Doe", "john@example.com", 30):
            assert expected in values
|
||||
|
||||
@pytest.mark.asyncio
async def test_multi_schema_handling(self, processor_with_mocks):
    """Two schemas, each receiving one object, must yield two separate tables."""
    processor, mock_cluster, mock_session = processor_with_mocks

    schema_fields = {
        "products": [
            {"name": "product_id", "type": "string", "primary_key": True},
            {"name": "name", "type": "string"},
            {"name": "price", "type": "float"}
        ],
        "orders": [
            {"name": "order_id", "type": "string", "primary_key": True},
            {"name": "customer_id", "type": "string"},
            {"name": "total", "type": "float"}
        ],
    }

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        # Configure multiple schemas
        config = {
            "schema": {
                schema_name: json.dumps({"name": schema_name, "fields": fields})
                for schema_name, fields in schema_fields.items()
            }
        }
        await processor.on_schema_config(config, version=1)
        assert len(processor.schemas) == 2

        # Create collections first
        for collection in ("catalog", "sales"):
            await processor.create_collection("shop", collection, {})

        # One object per schema
        pending = [
            ExtractedObject(
                metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
                schema_name="products",
                values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
                confidence=0.9,
                source_span="Product..."
            ),
            ExtractedObject(
                metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
                schema_name="orders",
                values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
                confidence=0.85,
                source_span="Order..."
            ),
        ]

        # Process both objects
        for extracted in pending:
            message = MagicMock()
            message.value.return_value = extracted
            await processor.on_object(message, None, None)

        # Separate tables were created; table names get an o_ prefix
        table_calls = [c for c in mock_session.execute.call_args_list
                       if "CREATE TABLE" in str(c)]
        assert len(table_calls) == 2
        for table_name in ("o_products", "o_orders"):
            assert any(table_name in str(c) for c in table_calls)
|
||||
|
||||
@pytest.mark.asyncio
async def test_missing_required_fields(self, processor_with_mocks):
    """An object lacking a required field is still written with one INSERT."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        # Schema with a required, non-key field
        schema_fields = [
            Field(name="id", type="string", size=50, primary=True, required=True),
            Field(name="required_field", type="string", size=100, required=True)
        ]
        processor.schemas["test_schema"] = RowSchema(
            name="test_schema",
            description="Test",
            fields=schema_fields
        )

        # Create collection first
        await processor.create_collection("test", "test", {})

        # Object that omits required_field
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
            schema_name="test_schema",
            values=[{"id": "123"}],
            confidence=0.8,
            source_span="Test"
        )

        # Should still process (Cassandra doesn't enforce NOT NULL)
        await processor.on_object(message, None, None)

        # Verify insert was attempted
        insert_calls = [c for c in mock_session.execute.call_args_list
                        if "INSERT INTO" in str(c)]
        assert len(insert_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_schema_without_primary_key(self, processor_with_mocks):
    """A schema with no primary key gets a ``synthetic_id uuid`` column and a UUID value."""
    processor, mock_cluster, mock_session = processor_with_mocks

    def executed(fragment):
        # Recorded session.execute() calls whose repr mentions the fragment.
        return [c for c in mock_session.execute.call_args_list
                if fragment in str(c)]

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        # Schema with no field marked as primary
        processor.schemas["events"] = RowSchema(
            name="events",
            description="Event log",
            fields=[
                Field(name="event_type", type="string", size=50),
                Field(name="timestamp", type="timestamp", size=0)
            ]
        )

        # Create collection first
        await processor.create_collection("logger", "app_events", {})

        # Process object
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
            schema_name="events",
            values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}],
            confidence=1.0,
            source_span="Event"
        )
        await processor.on_object(message, None, None)

        # The table definition must carry the synthetic key column
        table_calls = executed("CREATE TABLE")
        assert len(table_calls) == 1
        assert "synthetic_id uuid" in str(table_calls[0])

        # The single insert must carry a generated UUID among its values
        insert_calls = executed("INSERT INTO")
        assert len(insert_calls) == 1
        values = insert_calls[0][0][1]
        assert any(isinstance(v, uuid.UUID) for v in values)
|
||||
|
||||
@pytest.mark.asyncio
async def test_authentication_handling(self, processor_with_mocks):
    """Configured credentials must reach PlainTextAuthProvider and Cluster."""
    processor, mock_cluster, mock_session = processor_with_mocks
    processor.cassandra_username = "cassandra_user"
    processor.cassandra_password = "cassandra_pass"

    with patch('trustgraph.storage.objects.cassandra.write.Cluster') as mock_cluster_class, \
            patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') as mock_auth:
        mock_cluster_class.return_value = mock_cluster

        # Trigger connection
        processor.connect_cassandra()

        # The auth provider is built from the configured credentials...
        mock_auth.assert_called_once_with(
            username="cassandra_user",
            password="cassandra_pass"
        )

        # ...and the cluster is created once with an auth_provider keyword.
        mock_cluster_class.assert_called_once()
        cluster_kwargs = mock_cluster_class.call_args[1]
        assert 'auth_provider' in cluster_kwargs
|
||||
|
||||
@pytest.mark.asyncio
async def test_error_handling_during_insert(self, processor_with_mocks):
    """An exception raised by the insert statement must propagate out of ``on_object``."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        processor.schemas["test"] = RowSchema(
            name="test",
            fields=[Field(name="id", type="string", size=50, primary=True)]
        )

        # Script the session: the third execute() call raises.
        keyspace_present = MagicMock()
        keyspace_present.one.return_value = MagicMock()  # Keyspace exists
        scripted_results = [
            keyspace_present,                 # keyspace existence check succeeds
            None,                             # table creation succeeds
            Exception("Connection timeout"),  # insert fails
        ]
        mock_session.execute.side_effect = scripted_results

        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
            schema_name="test",
            values=[{"id": "123"}],
            confidence=0.9,
            source_span="Test"
        )

        # Should raise the exception
        with pytest.raises(Exception, match="Connection timeout"):
            await processor.on_object(message, None, None)
|
||||
|
||||
@pytest.mark.asyncio
async def test_collection_partitioning(self, processor_with_mocks):
    """Each insert must carry the collection of the object that produced it."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        processor.schemas["data"] = RowSchema(
            name="data",
            fields=[Field(name="id", type="string", size=50, primary=True)]
        )

        collections = ["import_jan", "import_feb", "import_mar"]

        # Create all collections first
        for name in collections:
            await processor.create_collection("analytics", name, {})

        # Then process one object per collection, in the same order
        for name in collections:
            message = MagicMock()
            message.value.return_value = ExtractedObject(
                metadata=Metadata(id=f"{name}-1", user="analytics", collection=name, metadata=[]),
                schema_name="data",
                values=[{"id": f"ID-{name}"}],
                confidence=0.9,
                source_span="Data"
            )
            await processor.on_object(message, None, None)

        # One insert per collection
        insert_calls = [c for c in mock_session.execute.call_args_list
                        if "INSERT INTO" in str(c)]
        assert len(insert_calls) == 3

        # Inserts were recorded in processing order, so pair them positionally
        for expected_collection, recorded in zip(collections, insert_calls):
            values = recorded[0][1]
            assert expected_collection in values
|
||||
|
||||
@pytest.mark.asyncio
async def test_batch_object_processing(self, processor_with_mocks):
    """An object with three value rows must produce one table and three inserts."""
    processor, mock_cluster, mock_session = processor_with_mocks

    def executed(fragment):
        # Recorded session.execute() calls whose repr mentions the fragment.
        return [c for c in mock_session.execute.call_args_list
                if fragment in str(c)]

    # (customer_id, name, email) expected in each insert, in batch order
    expected_rows = [
        ("CUST001", "John Doe", "john@example.com"),
        ("CUST002", "Jane Smith", "jane@example.com"),
        ("CUST003", "Bob Johnson", "bob@example.com"),
    ]

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        # Configure schema
        schema_definition = {
            "name": "batch_customers",
            "description": "Customer batch data",
            "fields": [
                {"name": "customer_id", "type": "string", "primary_key": True},
                {"name": "name", "type": "string", "required": True},
                {"name": "email", "type": "string", "indexed": True}
            ]
        }
        config = {"schema": {"batch_customers": json.dumps(schema_definition)}}
        await processor.on_schema_config(config, version=1)

        # Batch object with multiple values
        batch_obj = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_import",
                metadata=[]
            ),
            schema_name="batch_customers",
            values=[
                {
                    "customer_id": "CUST001",
                    "name": "John Doe",
                    "email": "john@example.com"
                },
                {
                    "customer_id": "CUST002",
                    "name": "Jane Smith",
                    "email": "jane@example.com"
                },
                {
                    "customer_id": "CUST003",
                    "name": "Bob Johnson",
                    "email": "bob@example.com"
                }
            ],
            confidence=0.92,
            source_span="Multiple customers extracted from document"
        )

        # Create collection first
        await processor.create_collection("test_user", "batch_import", {})

        message = MagicMock()
        message.value.return_value = batch_obj
        await processor.on_object(message, None, None)

        # Verify table creation
        table_calls = executed("CREATE TABLE")
        assert len(table_calls) == 1
        assert "o_batch_customers" in str(table_calls[0])

        # Should have 3 separate inserts for the 3 objects in the batch
        insert_calls = executed("INSERT INTO")
        assert len(insert_calls) == 3

        # Check each insert has the collection plus its own row's data
        for recorded, (customer_id, name, email) in zip(insert_calls, expected_rows):
            values = recorded[0][1]
            assert "batch_import" in values
            assert customer_id in values
            assert name in values
            assert email in values
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_batch_processing(self, processor_with_mocks):
    """An object with an empty values list creates the table but inserts nothing."""
    processor, mock_cluster, mock_session = processor_with_mocks

    def count_executed(fragment):
        # Number of recorded session.execute() calls mentioning the fragment.
        return len([c for c in mock_session.execute.call_args_list
                    if fragment in str(c)])

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        processor.schemas["empty_test"] = RowSchema(
            name="empty_test",
            fields=[Field(name="id", type="string", size=50, primary=True)]
        )

        # Create collection first
        await processor.create_collection("test", "empty", {})

        # Process an object whose batch is empty
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
            schema_name="empty_test",
            values=[],
            confidence=1.0,
            source_span="No objects found"
        )
        await processor.on_object(message, None, None)

        # Should still create table
        assert count_executed("CREATE TABLE") == 1

        # Should not create any insert statements for empty batch
        assert count_executed("INSERT INTO") == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_mixed_single_and_batch_objects(self, processor_with_mocks):
    """A one-row object followed by a two-row object must yield three inserts."""
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        processor.schemas["mixed_test"] = RowSchema(
            name="mixed_test",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="data", type="string", size=100)
            ]
        )

        # Create collection first
        await processor.create_collection("test", "mixed", {})

        pending = [
            # Single object (backward compatibility): array with one item
            ExtractedObject(
                metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
                schema_name="mixed_test",
                values=[{"id": "single-1", "data": "single data"}],
                confidence=0.9,
                source_span="Single object"
            ),
            # Batch object
            ExtractedObject(
                metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
                schema_name="mixed_test",
                values=[
                    {"id": "batch-1", "data": "batch data 1"},
                    {"id": "batch-2", "data": "batch data 2"}
                ],
                confidence=0.85,
                source_span="Batch objects"
            ),
        ]

        # Process both
        for extracted in pending:
            message = MagicMock()
            message.value.return_value = extracted
            await processor.on_object(message, None, None)

        # Should have 3 total inserts (1 + 2)
        insert_calls = [c for c in mock_session.execute.call_args_list
                        if "INSERT INTO" in str(c)]
        assert len(insert_calls) == 3
|
||||
492
tests/integration/test_rows_cassandra_integration.py
Normal file
492
tests/integration/test_rows_cassandra_integration.py
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
"""
|
||||
Integration tests for Cassandra Row Storage (Unified Table Implementation)
|
||||
|
||||
These tests verify the end-to-end functionality of storing ExtractedObjects
|
||||
in the unified Cassandra rows table, including table creation, data insertion,
|
||||
and error handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
from trustgraph.storage.rows.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestRowsCassandraIntegration:
|
||||
"""Integration tests for Cassandra row storage with unified table"""
|
||||
|
||||
@pytest.fixture
def mock_cassandra_session(self):
    """Return a MagicMock Cassandra session that remembers created keyspaces.

    The session's ``execute`` extracts the keyspace name from each
    ``CREATE KEYSPACE IF NOT EXISTS <name>`` statement.  A query mentioning
    ``system_schema.keyspaces`` gets a result whose ``one()`` returns a
    MagicMock if its first positional argument names a remembered keyspace;
    in every other case ``one()`` returns ``None``.
    """
    import re

    create_pattern = re.compile(r'CREATE KEYSPACE IF NOT EXISTS (\w+)')
    remembered = set()

    def scripted_execute(query, *args, **kwargs):
        statement = str(query)

        # Track keyspace creation
        if "CREATE KEYSPACE" in statement:
            hit = create_pattern.search(statement)
            if hit:
                remembered.add(hit.group(1))

        # Only a keyspace existence check for a remembered name finds a row
        row = None
        if "system_schema.keyspaces" in statement:
            if args and args[0] in remembered:
                row = MagicMock()

        reply = MagicMock()
        reply.one.return_value = row
        return reply

    session_double = MagicMock()
    session_double.execute = MagicMock(side_effect=scripted_execute)
    return session_double
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cassandra_cluster(self, mock_cassandra_session):
    """Return a MagicMock cluster whose ``connect()`` hands back the mocked session."""
    fake_cluster = MagicMock()
    fake_cluster.shutdown = MagicMock()
    fake_cluster.connect.return_value = mock_cassandra_session
    return fake_cluster
|
||||
|
||||
@pytest.fixture
|
||||
def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
    """Assemble a MagicMock processor that runs the real ``Processor`` methods.

    The processor itself is a ``MagicMock``; plain attributes are seeded
    with empty/default values and a fixed list of ``Processor`` methods is
    bound onto it so the tests exercise the actual implementation.
    ``collection_exists`` is stubbed to always return ``True``.

    Returns:
        A ``(processor, cluster, session)`` tuple; the cluster and session
        are the fixtures passed in, returned unchanged.
    """
    processor = MagicMock()

    # Seed the plain attributes with their initial values
    initial_state = {
        "cassandra_host": ["localhost"],
        "cassandra_username": None,
        "cassandra_password": None,
        "config_key": "schema",
        "schemas": {},
        "known_keyspaces": set(),
        "tables_initialized": set(),
        "registered_partitions": set(),
        "cluster": None,
        "session": None,
    }
    for attribute, value in initial_state.items():
        setattr(processor, attribute, value)

    # Bind actual methods from the new unified table implementation
    real_methods = (
        "connect_cassandra",
        "ensure_keyspace",
        "ensure_tables",
        "sanitize_name",
        "get_index_names",
        "build_index_value",
        "register_partitions",
        "on_schema_config",
        "on_object",
    )
    for method_name in real_methods:
        implementation = getattr(Processor, method_name)
        setattr(processor, method_name, implementation.__get__(processor, Processor))

    processor.collection_exists = MagicMock(return_value=True)

    return processor, mock_cassandra_cluster, mock_cassandra_session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_object_storage(self, processor_with_mocks):
    """Drive one customer record from schema configuration through to storage.

    Loads a ``customer_records`` schema through ``on_schema_config``, feeds
    a single-row ``ExtractedObject`` to ``on_object`` and then inspects the
    statements recorded on the mocked session: one keyspace creation, two
    table creations (``rows`` and ``row_partitions``) and two data inserts
    into the rows table.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    customer_schema = {
        "name": "customer_records",
        "description": "Customer information",
        "fields": [
            {"name": "customer_id", "type": "string", "primary_key": True},
            {"name": "name", "type": "string", "required": True},
            {"name": "email", "type": "string", "indexed": True},
            {"name": "age", "type": "integer"}
        ]
    }
    customer_row = {
        "customer_id": "CUST001",
        "name": "John Doe",
        "email": "john@example.com",
        "age": "30"
    }

    with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
        # Step 1: Configure schema
        await processor.on_schema_config(
            {"schema": {"customer_records": json.dumps(customer_schema)}},
            version=1
        )
        assert "customer_records" in processor.schemas

        # Step 2: Process an ExtractedObject
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(
                id="doc-001",
                user="test_user",
                collection="import_2024",
                metadata=[]
            ),
            schema_name="customer_records",
            values=[customer_row],
            confidence=0.95,
            source_span="Customer: John Doe..."
        )
        await processor.on_object(message, None, None)

        # Verify Cassandra interactions
        assert mock_cluster.connect.called

        # Every check below is a substring match, so render each call once
        executed = [str(recorded) for recorded in mock_session.execute.call_args_list]

        # Verify keyspace creation
        keyspace_stmts = [text for text in executed if "CREATE KEYSPACE" in text]
        assert len(keyspace_stmts) == 1
        assert "test_user" in keyspace_stmts[0]

        # Verify unified table creation (rows table, not per-schema table)
        table_stmts = [text for text in executed if "CREATE TABLE" in text]
        assert len(table_stmts) == 2  # rows table + row_partitions table
        assert any("rows" in text for text in table_stmts)
        assert any("row_partitions" in text for text in table_stmts)

        # Verify the rows table has correct structure
        rows_ddl = [text for text in table_stmts if ".rows" in text][0]
        expected_columns = (
            "collection text",
            "schema_name text",
            "index_name text",
            "data map<text, text>",
        )
        for column in expected_columns:
            assert column in rows_ddl

        # Verify data insertion into unified table:
        # one insert for customer_id (primary), one for email (indexed)
        data_inserts = [
            text for text in executed
            if "INSERT INTO" in text
            and ".rows" in text
            and "row_partitions" not in text
        ]
        assert len(data_inserts) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_schema_handling(self, processor_with_mocks):
    """Test handling multiple schemas stored in unified table.

    Configures two schemas (``products`` and ``orders``), stores one object
    for each under the same user, and verifies that:

    * only the two shared tables (``rows`` and ``row_partitions``) are
      created, rather than one table per schema;
    * data inserts are issued against the rows table;
    * each schema's row is written under its own collection.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
        # Configure multiple schemas
        config = {
            "schema": {
                "products": json.dumps({
                    "name": "products",
                    "fields": [
                        {"name": "product_id", "type": "string", "primary_key": True},
                        {"name": "name", "type": "string"},
                        {"name": "price", "type": "float"}
                    ]
                }),
                "orders": json.dumps({
                    "name": "orders",
                    "fields": [
                        {"name": "order_id", "type": "string", "primary_key": True},
                        {"name": "customer_id", "type": "string"},
                        {"name": "total", "type": "float"}
                    ]
                })
            }
        }

        await processor.on_schema_config(config, version=1)
        assert len(processor.schemas) == 2

        # Process objects for different schemas
        product_obj = ExtractedObject(
            metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
            schema_name="products",
            values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
            confidence=0.9,
            source_span="Product..."
        )

        order_obj = ExtractedObject(
            metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
            schema_name="orders",
            values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
            confidence=0.85,
            source_span="Order..."
        )

        # Process both objects
        for obj in [product_obj, order_obj]:
            msg = MagicMock()
            msg.value.return_value = obj
            await processor.on_object(msg, None, None)

        # All data goes into the same unified rows table
        table_calls = [call for call in mock_session.execute.call_args_list
                       if "CREATE TABLE" in str(call)]
        # Should only create 2 tables: rows + row_partitions (not per-schema tables)
        assert len(table_calls) == 2

        # Verify data inserts go to unified rows table
        rows_insert_calls = [call for call in mock_session.execute.call_args_list
                             if "INSERT INTO" in str(call) and ".rows" in str(call)
                             and "row_partitions" not in str(call)]
        assert len(rows_insert_calls) > 0

        # Re-checking '".rows" in str(call)' here would be vacuous: the
        # filter above already guarantees it.  Instead verify that both
        # schemas were actually written, each under its own collection.
        # Insert parameters are positional:
        # (collection, schema_name, index_name, index_value, data, source)
        stored = {
            (call[0][1][0], call[0][1][1]) for call in rows_insert_calls
        }
        assert stored == {("catalog", "products"), ("sales", "orders")}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_index_storage(self, processor_with_mocks):
    """Check that one row is written once per index (primary plus indexed fields).

    The schema is registered directly on ``processor.schemas`` with a
    primary ``id`` field, two indexed fields and one plain field.  After
    storing a single row, three data inserts are expected and the third
    insert parameter (the index name) must cover ``id``, ``category`` and
    ``status``.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    def is_rows_insert(recorded):
        rendered = str(recorded)
        return (
            "INSERT INTO" in rendered
            and ".rows" in rendered
            and "row_partitions" not in rendered
        )

    with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
        # Schema with multiple indexed fields
        schema_fields = [
            Field(name="id", type="string", size=50, primary=True),
            Field(name="category", type="string", size=50, indexed=True),
            Field(name="status", type="string", size=50, indexed=True),
            Field(name="description", type="string", size=200)  # Not indexed
        ]
        processor.schemas["indexed_data"] = RowSchema(
            name="indexed_data",
            fields=schema_fields
        )

        row = {
            "id": "123",
            "category": "electronics",
            "status": "active",
            "description": "A product"
        }
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
            schema_name="indexed_data",
            values=[row],
            confidence=0.9,
            source_span="Test"
        )

        await processor.on_object(message, None, None)

        # Should have 3 data inserts (one per indexed field: id, category, status)
        data_inserts = list(filter(is_rows_insert, mock_session.execute.call_args_list))
        assert len(data_inserts) == 3

        # Verify different index names were used;
        # index_name is the 3rd insert parameter
        used_indexes = {recorded[0][1][2] for recorded in data_inserts}
        assert used_indexes == {"id", "category", "status"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_handling(self, processor_with_mocks):
    """Check that configured credentials reach the Cassandra driver.

    With a username and password set on the processor,
    ``connect_cassandra`` must build a ``PlainTextAuthProvider`` from them
    and construct the cluster exactly once with an ``auth_provider``
    keyword argument.
    """
    processor, mock_cluster, _ = processor_with_mocks
    processor.cassandra_username = "cassandra_user"
    processor.cassandra_password = "cassandra_pass"

    with patch(
        'trustgraph.storage.rows.cassandra.write.Cluster',
        return_value=mock_cluster
    ) as mock_cluster_class, patch(
        'trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider'
    ) as mock_auth:
        # Trigger connection
        processor.connect_cassandra()

        # Verify authentication was configured
        mock_auth.assert_called_once_with(
            username="cassandra_user",
            password="cassandra_pass"
        )
        mock_cluster_class.assert_called_once()
        assert 'auth_provider' in mock_cluster_class.call_args[1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_object_processing(self, processor_with_mocks):
    """Check that every row of a multi-value object is stored under each index.

    The ``batch_customers`` schema has a primary ``customer_id`` and an
    indexed ``email``; an object carrying three rows must therefore produce
    six data inserts, while still creating only the two shared tables.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    schema_definition = {
        "name": "batch_customers",
        "description": "Customer batch data",
        "fields": [
            {"name": "customer_id", "type": "string", "primary_key": True},
            {"name": "name", "type": "string", "required": True},
            {"name": "email", "type": "string", "indexed": True}
        ]
    }
    customers = [
        {"customer_id": customer_id, "name": name, "email": email}
        for customer_id, name, email in (
            ("CUST001", "John Doe", "john@example.com"),
            ("CUST002", "Jane Smith", "jane@example.com"),
            ("CUST003", "Bob Johnson", "bob@example.com"),
        )
    ]

    with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
        # Configure schema
        await processor.on_schema_config(
            {"schema": {"batch_customers": json.dumps(schema_definition)}},
            version=1
        )

        # Process batch object with multiple values
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_import",
                metadata=[]
            ),
            schema_name="batch_customers",
            values=customers,
            confidence=0.92,
            source_span="Multiple customers extracted from document"
        )
        await processor.on_object(message, None, None)

        executed = [str(recorded) for recorded in mock_session.execute.call_args_list]

        # Verify unified table creation
        table_stmts = [text for text in executed if "CREATE TABLE" in text]
        assert len(table_stmts) == 2  # rows + row_partitions

        # Each row in batch gets 2 data inserts (customer_id primary + email indexed)
        # 3 rows * 2 indexes = 6 data inserts
        data_inserts = [
            text for text in executed
            if "INSERT INTO" in text
            and ".rows" in text
            and "row_partitions" not in text
        ]
        assert len(data_inserts) == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_batch_processing(self, processor_with_mocks):
    """Check that an object with an empty ``values`` list writes no data rows.

    Only inserts into the rows table are counted; statements touching
    ``row_partitions`` are deliberately ignored, since partition
    registration may still happen.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
        only_field = Field(name="id", type="string", size=50, primary=True)
        processor.schemas["empty_test"] = RowSchema(
            name="empty_test",
            fields=[only_field]
        )

        # Process empty batch object
        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
            schema_name="empty_test",
            values=[],  # Empty batch
            confidence=1.0,
            source_span="No objects found"
        )

        await processor.on_object(message, None, None)

        # Should not create any data insert statements for empty batch
        # (partition registration may still happen)
        data_inserts = []
        for recorded in mock_session.execute.call_args_list:
            rendered = str(recorded)
            if "INSERT INTO" not in rendered or ".rows" not in rendered:
                continue
            if "row_partitions" in rendered:
                continue
            data_inserts.append(recorded)
        assert len(data_inserts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_stored_as_map(self, processor_with_mocks):
    """Check that the row payload is passed to the insert as a dict of strings.

    Stores one row and examines the first data insert: its fifth positional
    parameter must be a ``dict`` holding the original field values.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    def is_rows_insert(recorded):
        rendered = str(recorded)
        return (
            "INSERT INTO" in rendered
            and ".rows" in rendered
            and "row_partitions" not in rendered
        )

    with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
        processor.schemas["map_test"] = RowSchema(
            name="map_test",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="name", type="string", size=100),
                Field(name="count", type="integer", size=0)
            ]
        )

        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
            schema_name="map_test",
            values=[{"id": "123", "name": "Test Item", "count": "42"}],
            confidence=0.9,
            source_span="Test"
        )

        await processor.on_object(message, None, None)

        # Verify insert uses map for data
        data_inserts = [
            recorded for recorded in mock_session.execute.call_args_list
            if is_rows_insert(recorded)
        ]
        assert len(data_inserts) >= 1

        # Check that data is passed as a dict (will be map in Cassandra).
        # Values are: (collection, schema_name, index_name, index_value, data, source)
        # so position 4 is the data map.
        stored_map = data_inserts[0][0][1][4]
        assert isinstance(stored_map, dict)
        for key, expected in (("id", "123"), ("name", "Test Item"), ("count", "42")):
            assert stored_map[key] == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partition_registration(self, processor_with_mocks):
    """Check that storing a row registers one partition per index.

    With a primary ``id`` and an indexed ``category`` field, two inserts
    into ``row_partitions`` are expected, and the
    ``(collection, schema_name)`` pair must be recorded in
    ``processor.registered_partitions``.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
        schema_fields = [
            Field(name="id", type="string", size=50, primary=True),
            Field(name="category", type="string", size=50, indexed=True)
        ]
        processor.schemas["partition_test"] = RowSchema(
            name="partition_test",
            fields=schema_fields
        )

        message = MagicMock()
        message.value.return_value = ExtractedObject(
            metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]),
            schema_name="partition_test",
            values=[{"id": "123", "category": "test"}],
            confidence=0.9,
            source_span="Test"
        )

        await processor.on_object(message, None, None)

        # Verify partition registration
        executed = [str(recorded) for recorded in mock_session.execute.call_args_list]
        partition_stmts = [
            text for text in executed
            if "INSERT INTO" in text and "row_partitions" in text
        ]
        # Should register partitions for each index (id, category)
        assert len(partition_stmts) == 2

        # Verify cache was updated
        assert ("my_collection", "partition_test") in processor.registered_partitions
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Integration tests for Objects GraphQL Query Service
|
||||
Integration tests for Rows GraphQL Query Service
|
||||
|
||||
These tests verify end-to-end functionality including:
|
||||
- Real Cassandra database operations
|
||||
|
|
@ -24,8 +24,8 @@ except Exception:
|
|||
DOCKER_AVAILABLE = False
|
||||
CassandraContainer = None
|
||||
|
||||
from trustgraph.query.objects.cassandra.service import Processor
|
||||
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from trustgraph.query.rows.cassandra.service import Processor
|
||||
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||
from trustgraph.schema import RowSchema, Field, ExtractedObject, Metadata
|
||||
|
||||
|
||||
|
|
@ -390,7 +390,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
processor.connect_cassandra()
|
||||
|
||||
# Create mock message
|
||||
request = ObjectsQueryRequest(
|
||||
request = RowsQueryRequest(
|
||||
user="msg_test_user",
|
||||
collection="msg_test_collection",
|
||||
query='{ customer_objects { customer_id name } }',
|
||||
|
|
@ -415,7 +415,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
|
||||
# Verify response structure
|
||||
sent_response = mock_response_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_response, ObjectsQueryResponse)
|
||||
assert isinstance(sent_response, RowsQueryResponse)
|
||||
|
||||
# Should have no system error (even if no data)
|
||||
assert sent_response.error is None
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
Integration tests for Structured Query Service
|
||||
|
||||
These tests verify the end-to-end functionality of the structured query service,
|
||||
testing orchestration between nlp-query and objects-query services.
|
||||
testing orchestration between nlp-query and rows-query services.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
|
|
@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.schema import (
|
||||
StructuredQueryRequest, StructuredQueryResponse,
|
||||
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
|
||||
ObjectsQueryRequest, ObjectsQueryResponse,
|
||||
RowsQueryRequest, RowsQueryResponse,
|
||||
Error, GraphQLError
|
||||
)
|
||||
from trustgraph.retrieval.structured_query.service import Processor
|
||||
|
|
@ -81,7 +81,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Mock Objects Query Service Response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": [{"id": "123", "name": "Alice Johnson", "email": "alice@example.com", "orders": [{"id": "456", "total": 750.0, "date": "2024-01-15"}]}]}',
|
||||
errors=None,
|
||||
|
|
@ -99,7 +99,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -121,7 +121,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
# Verify Objects service call
|
||||
mock_objects_client.request.assert_called_once()
|
||||
objects_call_args = mock_objects_client.request.call_args[0][0]
|
||||
assert isinstance(objects_call_args, ObjectsQueryRequest)
|
||||
assert isinstance(objects_call_args, RowsQueryRequest)
|
||||
assert "customers" in objects_call_args.query
|
||||
assert "orders" in objects_call_args.query
|
||||
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
|
||||
|
|
@ -220,7 +220,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Mock Objects service failure
|
||||
objects_error_response = ObjectsQueryResponse(
|
||||
objects_error_response = RowsQueryResponse(
|
||||
error=Error(type="graphql-schema-error", message="Table 'nonexistent_table' does not exist in schema"),
|
||||
data=None,
|
||||
errors=None,
|
||||
|
|
@ -237,7 +237,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -255,7 +255,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
|
||||
assert response.error is not None
|
||||
assert response.error.type == "structured-query-error"
|
||||
assert "Objects query service error" in response.error.message
|
||||
assert "Rows query service error" in response.error.message
|
||||
assert "nonexistent_table" in response.error.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -298,7 +298,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
]
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=None, # No data when validation fails
|
||||
errors=validation_errors,
|
||||
|
|
@ -315,7 +315,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -422,7 +422,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
]
|
||||
}
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=json.dumps(complex_data),
|
||||
errors=None,
|
||||
|
|
@ -443,7 +443,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -503,7 +503,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Mock empty Objects response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": []}', # Empty result set
|
||||
errors=None,
|
||||
|
|
@ -520,7 +520,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -577,7 +577,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
confidence=0.9
|
||||
)
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=f'{{"test_{i}": [{{"id": "{i}"}}]}}',
|
||||
errors=None,
|
||||
|
|
@ -599,7 +599,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
if service_name == "nlp-query-request":
|
||||
service_call_count += 1
|
||||
return nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
service_call_count += 1
|
||||
return objects_client
|
||||
elif service_name == "response":
|
||||
|
|
@ -700,7 +700,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Mock Objects response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"orders": [{"id": "123", "total": 125.50, "date": "2024-01-15"}]}',
|
||||
errors=None,
|
||||
|
|
@ -717,7 +717,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue