mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-26 15:55:16 +02:00
feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)
Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly at the trusted flow.workspace layer
rather than through client-supplied message fields.
Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
captures the workspace/collection/flow hierarchy.
Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
service layer.
- Translators updated to not serialise/deserialise user.
API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- Websocket async-api messages updated.
- Removed the unused parameters/User.yaml.
Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
scoped by workspace. Config client API takes workspace as first
positional arg.
- `flow.workspace` set at flow start time by the infrastructure;
no longer pass-through from clients.
- Tool service drops user-personalisation passthrough.
CLI + SDK
---------
- tg-init-workspace and workspace-aware import/export.
- All tg-* commands drop user args; accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
library) drop user kwargs from every method signature.
MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
keyed per user.
Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
whose blueprint template was parameterised AND no remaining
live flow (across all workspaces) still resolves to that topic.
Three scopes fall out naturally from template analysis:
* {id} -> per-flow, deleted on stop
* {blueprint} -> per-blueprint, kept while any flow of the
same blueprint exists
* {workspace} -> per-workspace, kept while any flow in the
workspace exists
* literal -> global, never deleted (e.g. tg.request.librarian)
Fixes a bug where stopping a flow silently destroyed the global
librarian exchange, wedging all library operations until manual
restart.
RabbitMQ backend
----------------
- heartbeat=60, blocked_connection_timeout=300. Catches silently
dead connections (broker restart, orphaned channels, network
partitions) within ~2 heartbeat windows, so the consumer
reconnects and re-binds its queue rather than sitting forever
on a zombie connection.
Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
This commit is contained in:
parent
9332089b3d
commit
d35473f7f7
377 changed files with 6868 additions and 5785 deletions
|
|
@ -72,7 +72,6 @@ def sample_message_data():
|
|||
},
|
||||
"DocumentRagQuery": {
|
||||
"query": "What is artificial intelligence?",
|
||||
"user": "test_user",
|
||||
"collection": "test_collection",
|
||||
"doc_limit": 10
|
||||
},
|
||||
|
|
@ -95,7 +94,6 @@ def sample_message_data():
|
|||
},
|
||||
"Metadata": {
|
||||
"id": "test-doc-123",
|
||||
"user": "test_user",
|
||||
"collection": "test_collection"
|
||||
},
|
||||
"Term": {
|
||||
|
|
@ -130,9 +128,8 @@ def invalid_message_data():
|
|||
{}, # Missing required fields
|
||||
],
|
||||
"DocumentRagQuery": [
|
||||
{"query": None, "user": "test", "collection": "test", "doc_limit": 10}, # Invalid query
|
||||
{"query": "test", "user": None, "collection": "test", "doc_limit": 10}, # Invalid user
|
||||
{"query": "test", "user": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
|
||||
{"query": None, "collection": "test", "doc_limit": 10}, # Invalid query
|
||||
{"query": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
|
||||
{"query": "test"}, # Missing required fields
|
||||
],
|
||||
"Term": [
|
||||
|
|
|
|||
|
|
@ -18,24 +18,18 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
|
||||
def test_request_schema_fields(self):
|
||||
"""Test that DocumentEmbeddingsRequest has expected fields"""
|
||||
# Create a request
|
||||
request = DocumentEmbeddingsRequest(
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify all expected fields exist
|
||||
assert hasattr(request, 'vector')
|
||||
assert hasattr(request, 'limit')
|
||||
assert hasattr(request, 'user')
|
||||
assert hasattr(request, 'collection')
|
||||
|
||||
# Verify field values
|
||||
assert request.vector == [0.1, 0.2, 0.3]
|
||||
assert request.limit == 10
|
||||
assert request.user == "test_user"
|
||||
assert request.collection == "test_collection"
|
||||
|
||||
def test_request_translator_decode(self):
|
||||
|
|
@ -45,7 +39,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
data = {
|
||||
"vector": [0.1, 0.2, 0.3, 0.4],
|
||||
"limit": 5,
|
||||
"user": "custom_user",
|
||||
"collection": "custom_collection"
|
||||
}
|
||||
|
||||
|
|
@ -54,7 +47,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vector == [0.1, 0.2, 0.3, 0.4]
|
||||
assert result.limit == 5
|
||||
assert result.user == "custom_user"
|
||||
assert result.collection == "custom_collection"
|
||||
|
||||
def test_request_translator_decode_with_defaults(self):
|
||||
|
|
@ -63,7 +55,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
|
||||
data = {
|
||||
"vector": [0.1, 0.2]
|
||||
# No limit, user, or collection provided
|
||||
# No limit or collection provided
|
||||
}
|
||||
|
||||
result = translator.decode(data)
|
||||
|
|
@ -71,7 +63,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vector == [0.1, 0.2]
|
||||
assert result.limit == 10 # Default
|
||||
assert result.user == "trustgraph" # Default
|
||||
assert result.collection == "default" # Default
|
||||
|
||||
def test_request_translator_encode(self):
|
||||
|
|
@ -81,7 +72,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
request = DocumentEmbeddingsRequest(
|
||||
vector=[0.5, 0.6],
|
||||
limit=20,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -90,7 +80,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert isinstance(result, dict)
|
||||
assert result["vector"] == [0.5, 0.6]
|
||||
assert result["limit"] == 20
|
||||
assert result["user"] == "test_user"
|
||||
assert result["collection"] == "test_collection"
|
||||
|
||||
|
||||
|
|
@ -219,7 +208,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
request_data = {
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
"limit": 5,
|
||||
"user": "test_user",
|
||||
"collection": "test_collection"
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -132,7 +132,6 @@ class TestDocumentRagMessageContracts:
|
|||
# Test required fields
|
||||
query = DocumentRagQuery(**query_data)
|
||||
assert hasattr(query, 'query')
|
||||
assert hasattr(query, 'user')
|
||||
assert hasattr(query, 'collection')
|
||||
assert hasattr(query, 'doc_limit')
|
||||
|
||||
|
|
@ -154,12 +153,10 @@ class TestDocumentRagMessageContracts:
|
|||
# Test valid query
|
||||
valid_query = DocumentRagQuery(
|
||||
query="What is AI?",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5
|
||||
)
|
||||
assert valid_query.query == "What is AI?"
|
||||
assert valid_query.user == "test_user"
|
||||
assert valid_query.collection == "test_collection"
|
||||
assert valid_query.doc_limit == 5
|
||||
|
||||
|
|
@ -400,7 +397,6 @@ class TestMetadataMessageContracts:
|
|||
|
||||
metadata = Metadata(**metadata_data)
|
||||
assert metadata.id == "test-doc-123"
|
||||
assert metadata.user == "test_user"
|
||||
assert metadata.collection == "test_collection"
|
||||
|
||||
def test_error_schema_contract(self):
|
||||
|
|
@ -491,7 +487,7 @@ class TestSchemaEvolutionContracts:
|
|||
required_fields = {
|
||||
"TextCompletionRequest": ["system", "prompt"],
|
||||
"TextCompletionResponse": ["error", "response", "model"],
|
||||
"DocumentRagQuery": ["query", "user", "collection"],
|
||||
"DocumentRagQuery": ["query", "collection"],
|
||||
"DocumentRagResponse": ["error", "response"],
|
||||
"AgentRequest": ["question", "history"],
|
||||
"AgentResponse": ["error"],
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ class TestOrchestrationFieldContracts:
|
|||
def test_agent_request_orchestration_fields_roundtrip(self):
|
||||
req = AgentRequest(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
correlation_id="corr-123",
|
||||
parent_session_id="parent-sess",
|
||||
|
|
@ -42,7 +41,6 @@ class TestOrchestrationFieldContracts:
|
|||
def test_agent_request_orchestration_fields_default_empty(self):
|
||||
req = AgentRequest(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
)
|
||||
|
||||
assert req.correlation_id == ""
|
||||
|
|
@ -82,7 +80,6 @@ class TestSubagentCompletionStepContract:
|
|||
)
|
||||
req = AgentRequest(
|
||||
question="goal",
|
||||
user="testuser",
|
||||
correlation_id="corr-123",
|
||||
history=[step],
|
||||
)
|
||||
|
|
@ -126,7 +123,6 @@ class TestSynthesisStepContract:
|
|||
|
||||
req = AgentRequest(
|
||||
question="Original question",
|
||||
user="testuser",
|
||||
pattern="supervisor",
|
||||
correlation_id="",
|
||||
session_id="parent-sess",
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ class TestRowsCassandraContracts:
|
|||
# Create test object with all required fields
|
||||
test_metadata = Metadata(
|
||||
id="test-doc-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -47,7 +46,6 @@ class TestRowsCassandraContracts:
|
|||
|
||||
# Verify metadata structure
|
||||
assert hasattr(test_object.metadata, 'id')
|
||||
assert hasattr(test_object.metadata, 'user')
|
||||
assert hasattr(test_object.metadata, 'collection')
|
||||
|
||||
# Verify types
|
||||
|
|
@ -150,7 +148,6 @@ class TestRowsCassandraContracts:
|
|||
original = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="serial-001",
|
||||
user="test_user",
|
||||
collection="test_coll",
|
||||
),
|
||||
schema_name="test_schema",
|
||||
|
|
@ -168,7 +165,6 @@ class TestRowsCassandraContracts:
|
|||
|
||||
# Verify round-trip
|
||||
assert decoded.metadata.id == original.metadata.id
|
||||
assert decoded.metadata.user == original.metadata.user
|
||||
assert decoded.metadata.collection == original.metadata.collection
|
||||
assert decoded.schema_name == original.schema_name
|
||||
assert decoded.values == original.values
|
||||
|
|
@ -228,8 +224,7 @@ class TestRowsCassandraContracts:
|
|||
# Create test object
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="meta-001",
|
||||
user="user123", # -> keyspace
|
||||
id="meta-001", # -> keyspace
|
||||
collection="coll456", # -> partition key
|
||||
),
|
||||
schema_name="table789", # -> table name
|
||||
|
|
@ -242,7 +237,6 @@ class TestRowsCassandraContracts:
|
|||
# - metadata.user -> Cassandra keyspace
|
||||
# - schema_name -> Cassandra table
|
||||
# - metadata.collection -> Part of primary key
|
||||
assert test_obj.metadata.user # Required for keyspace
|
||||
assert test_obj.schema_name # Required for table
|
||||
assert test_obj.metadata.collection # Required for partition key
|
||||
|
||||
|
|
@ -256,7 +250,6 @@ class TestRowsCassandraContractsBatch:
|
|||
# Create test object with multiple values in batch
|
||||
test_metadata = Metadata(
|
||||
id="batch-doc-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -302,7 +295,6 @@ class TestRowsCassandraContractsBatch:
|
|||
"""Test empty batch ExtractedObject contract"""
|
||||
test_metadata = Metadata(
|
||||
id="empty-batch-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -324,7 +316,6 @@ class TestRowsCassandraContractsBatch:
|
|||
"""Test single-item batch (backward compatibility) contract"""
|
||||
test_metadata = Metadata(
|
||||
id="single-batch-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -353,7 +344,6 @@ class TestRowsCassandraContractsBatch:
|
|||
original = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-serial-001",
|
||||
user="test_user",
|
||||
collection="test_coll",
|
||||
),
|
||||
schema_name="test_schema",
|
||||
|
|
@ -375,7 +365,6 @@ class TestRowsCassandraContractsBatch:
|
|||
|
||||
# Verify round-trip for batch
|
||||
assert decoded.metadata.id == original.metadata.id
|
||||
assert decoded.metadata.user == original.metadata.user
|
||||
assert decoded.metadata.collection == original.metadata.collection
|
||||
assert decoded.schema_name == original.schema_name
|
||||
assert len(decoded.values) == len(original.values)
|
||||
|
|
@ -425,8 +414,7 @@ class TestRowsCassandraContractsBatch:
|
|||
# 3. Be stored in the same keyspace (user)
|
||||
|
||||
test_metadata = Metadata(
|
||||
id="partition-test-001",
|
||||
user="consistent_user", # Same keyspace
|
||||
id="partition-test-001", # Same keyspace
|
||||
collection="consistent_collection", # Same partition
|
||||
)
|
||||
|
||||
|
|
@ -443,7 +431,6 @@ class TestRowsCassandraContractsBatch:
|
|||
)
|
||||
|
||||
# Verify consistency contract
|
||||
assert batch_object.metadata.user # Must have user for keyspace
|
||||
assert batch_object.metadata.collection # Must have collection for partition key
|
||||
|
||||
# Verify unique primary keys in batch
|
||||
|
|
|
|||
|
|
@ -21,29 +21,25 @@ class TestRowsGraphQLQueryContracts:
|
|||
"""Test RowsQueryRequest schema structure and required fields"""
|
||||
# Create test request with all required fields
|
||||
test_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ customers { id name email } }',
|
||||
variables={"status": "active", "limit": "10"},
|
||||
operation_name="GetCustomers"
|
||||
)
|
||||
|
||||
|
||||
# Verify all required fields are present
|
||||
assert hasattr(test_request, 'user')
|
||||
assert hasattr(test_request, 'collection')
|
||||
assert hasattr(test_request, 'collection')
|
||||
assert hasattr(test_request, 'query')
|
||||
assert hasattr(test_request, 'variables')
|
||||
assert hasattr(test_request, 'operation_name')
|
||||
|
||||
|
||||
# Verify field types
|
||||
assert isinstance(test_request.user, str)
|
||||
assert isinstance(test_request.collection, str)
|
||||
assert isinstance(test_request.query, str)
|
||||
assert isinstance(test_request.variables, dict)
|
||||
assert isinstance(test_request.operation_name, str)
|
||||
|
||||
|
||||
# Verify content
|
||||
assert test_request.user == "test_user"
|
||||
assert test_request.collection == "test_collection"
|
||||
assert "customers" in test_request.query
|
||||
assert test_request.variables["status"] == "active"
|
||||
|
|
@ -53,15 +49,13 @@ class TestRowsGraphQLQueryContracts:
|
|||
"""Test RowsQueryRequest with minimal required fields"""
|
||||
# Create request with only essential fields
|
||||
minimal_request = RowsQueryRequest(
|
||||
user="user",
|
||||
collection="collection",
|
||||
query='{ test }',
|
||||
variables={},
|
||||
operation_name=""
|
||||
)
|
||||
|
||||
|
||||
# Verify minimal request is valid
|
||||
assert minimal_request.user == "user"
|
||||
assert minimal_request.collection == "collection"
|
||||
assert minimal_request.query == '{ test }'
|
||||
assert minimal_request.variables == {}
|
||||
|
|
@ -187,22 +181,20 @@ class TestRowsGraphQLQueryContracts:
|
|||
"""Test that request/response can be serialized/deserialized correctly"""
|
||||
# Create original request
|
||||
original_request = RowsQueryRequest(
|
||||
user="serialization_test",
|
||||
collection="test_data",
|
||||
query='{ orders(limit: 5) { id total customer { name } } }',
|
||||
variables={"limit": "5", "status": "active"},
|
||||
operation_name="GetRecentOrders"
|
||||
)
|
||||
|
||||
|
||||
# Test request serialization using Pulsar schema
|
||||
request_schema = AvroSchema(RowsQueryRequest)
|
||||
|
||||
|
||||
# Encode and decode request
|
||||
encoded_request = request_schema.encode(original_request)
|
||||
decoded_request = request_schema.decode(encoded_request)
|
||||
|
||||
|
||||
# Verify request round-trip
|
||||
assert decoded_request.user == original_request.user
|
||||
assert decoded_request.collection == original_request.collection
|
||||
assert decoded_request.query == original_request.query
|
||||
assert decoded_request.variables == original_request.variables
|
||||
|
|
@ -245,7 +237,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
"""Test supported GraphQL query formats"""
|
||||
# Test basic query
|
||||
basic_query = RowsQueryRequest(
|
||||
user="test", collection="test", query='{ customers { id } }',
|
||||
collection="test", query='{ customers { id } }',
|
||||
variables={}, operation_name=""
|
||||
)
|
||||
assert "customers" in basic_query.query
|
||||
|
|
@ -254,7 +246,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
# Test query with variables
|
||||
parameterized_query = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
collection="test",
|
||||
query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }',
|
||||
variables={"status": "active", "limit": "10"},
|
||||
operation_name="GetCustomers"
|
||||
|
|
@ -266,7 +258,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
# Test complex nested query
|
||||
nested_query = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
collection="test",
|
||||
query='''
|
||||
{
|
||||
customers(limit: 10) {
|
||||
|
|
@ -297,7 +289,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
# This test verifies the current contract, though ideally we'd support all JSON types
|
||||
|
||||
variables_test = RowsQueryRequest(
|
||||
user="test", collection="test", query='{ test }',
|
||||
collection="test", query='{ test }',
|
||||
variables={
|
||||
"string_var": "test_value",
|
||||
"numeric_var": "123", # Numbers as strings due to Map(String()) limitation
|
||||
|
|
@ -318,22 +310,18 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
def test_cassandra_context_fields_contract(self):
|
||||
"""Test that request contains necessary fields for Cassandra operations"""
|
||||
# Verify request has fields needed for Cassandra keyspace/table targeting
|
||||
# Verify request has fields needed for partition key targeting
|
||||
request = RowsQueryRequest(
|
||||
user="keyspace_name", # Maps to Cassandra keyspace
|
||||
collection="partition_collection", # Used in partition key
|
||||
query='{ objects { id } }',
|
||||
variables={}, operation_name=""
|
||||
)
|
||||
|
||||
# These fields are required for proper Cassandra operations
|
||||
assert request.user # Required for keyspace identification
|
||||
assert request.collection # Required for partition key
|
||||
|
||||
|
||||
# Required for partition key
|
||||
assert request.collection
|
||||
|
||||
# Verify field naming follows TrustGraph patterns (matching other query services)
|
||||
# This matches TriplesQueryRequest, DocumentEmbeddingsRequest patterns
|
||||
assert hasattr(request, 'user') # Same as TriplesQueryRequest.user
|
||||
assert hasattr(request, 'collection') # Same as TriplesQueryRequest.collection
|
||||
assert hasattr(request, 'collection')
|
||||
|
||||
def test_graphql_extensions_contract(self):
|
||||
"""Test GraphQL extensions field format and usage"""
|
||||
|
|
@ -405,7 +393,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
# Request to execute specific operation
|
||||
multi_op_request = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
collection="test",
|
||||
query=multi_op_query,
|
||||
variables={},
|
||||
operation_name="GetCustomers"
|
||||
|
|
@ -418,7 +406,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
# Test single operation (operation_name optional)
|
||||
single_op_request = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
collection="test",
|
||||
query='{ customers { id } }',
|
||||
variables={}, operation_name=""
|
||||
)
|
||||
|
|
|
|||
|
|
@ -41,10 +41,11 @@ class TestSchemaFieldContracts:
|
|||
def test_metadata_fields(self):
|
||||
# NOTE: there is no `metadata` field. A previous regression
|
||||
# constructed Metadata(metadata=...) and crashed at runtime.
|
||||
# `user` was also dropped in the workspace refactor — workspace
|
||||
# now flows via flow.workspace, not via message payload.
|
||||
assert _field_names(Metadata) == {
|
||||
"id",
|
||||
"root",
|
||||
"user",
|
||||
"collection",
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -93,7 +93,6 @@ class TestStructuredDataSchemaContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="structured-data-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -118,7 +117,6 @@ class TestStructuredDataSchemaContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="extracted-obj-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -143,7 +141,6 @@ class TestStructuredDataSchemaContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="extracted-batch-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -177,7 +174,6 @@ class TestStructuredDataSchemaContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="extracted-empty-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -277,7 +273,6 @@ class TestStructuredEmbeddingsContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="struct-embed-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -308,7 +303,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_structured_data_submission_serialization(self):
|
||||
"""Test StructuredDataSubmission serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
metadata = Metadata(id="test", collection="col")
|
||||
submission_data = {
|
||||
"metadata": metadata,
|
||||
"format": "json",
|
||||
|
|
@ -323,7 +318,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_extracted_object_serialization(self):
|
||||
"""Test ExtractedObject serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
metadata = Metadata(id="test", collection="col")
|
||||
object_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
|
|
@ -373,7 +368,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_extracted_object_batch_serialization(self):
|
||||
"""Test ExtractedObject batch serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
metadata = Metadata(id="test", collection="col")
|
||||
batch_object_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
|
|
@ -392,7 +387,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_extracted_object_empty_batch_serialization(self):
|
||||
"""Test ExtractedObject empty batch serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
metadata = Metadata(id="test", collection="col")
|
||||
empty_batch_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration:
|
|||
async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
|
||||
"""Test basic agent integration with structured query tool"""
|
||||
# Arrange - Load tool configuration
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
# Create agent request
|
||||
request = AgentRequest(
|
||||
|
|
@ -66,7 +66,6 @@ class TestAgentStructuredQueryIntegration:
|
|||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -119,6 +118,7 @@ Args: {
|
|||
# Mock flow parameter in agent_processor.on_request
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -146,14 +146,13 @@ Args: {
|
|||
async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
|
||||
"""Test agent handling of structured query errors"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Find data from a table that doesn't exist using structured query.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -199,6 +198,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -221,14 +221,13 @@ Args: {
|
|||
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
|
||||
"""Test agent using structured query in multi-step reasoning"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="First find all customers from California, then tell me how many orders they have made.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -279,6 +278,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -313,14 +313,13 @@ Args: {
|
|||
}
|
||||
}
|
||||
|
||||
await agent_processor.on_tools_config(tool_config_with_collection, "v1")
|
||||
await agent_processor.on_tools_config("default", tool_config_with_collection, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Query the sales data for recent transactions.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -371,6 +370,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -394,10 +394,10 @@ Args: {
|
|||
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
|
||||
"""Test that structured query tool arguments are properly validated"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
# Check that the tool was registered with correct arguments
|
||||
tools = agent_processor.agent.tools
|
||||
tools = agent_processor.agents["default"].tools
|
||||
assert "structured-query" in tools
|
||||
|
||||
structured_tool = tools["structured-query"]
|
||||
|
|
@ -414,14 +414,13 @@ Args: {
|
|||
async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
|
||||
"""Test that structured query results are properly formatted for agent consumption"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Get customer information and format it nicely.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -482,6 +481,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
|
|||
|
|
@ -40,14 +40,13 @@ class TestEndToEndConfigurationFlow:
|
|||
|
||||
# Create a mock message to trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
# This should create TrustGraph with environment config
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('test_user', mock_message)
|
||||
|
||||
# Verify Cluster was created with correct hosts
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -144,13 +143,12 @@ class TestConfigurationPriorityEndToEnd:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('test_user', mock_message)
|
||||
|
||||
# Should use CLI parameters, not environment
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -201,7 +199,6 @@ class TestConfigurationPriorityEndToEnd:
|
|||
|
||||
# Mock query to trigger TrustGraph creation
|
||||
mock_query = MagicMock()
|
||||
mock_query.user = 'default_user'
|
||||
mock_query.collection = 'default_collection'
|
||||
mock_query.s = None
|
||||
mock_query.p = None
|
||||
|
|
@ -213,7 +210,7 @@ class TestConfigurationPriorityEndToEnd:
|
|||
mock_tg_instance.get_all.return_value = []
|
||||
processor.tg = mock_tg_instance
|
||||
|
||||
await processor.query_triples(mock_query)
|
||||
await processor.query_triples('default_user', mock_query)
|
||||
|
||||
# Should use defaults
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -244,13 +241,12 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'legacy_user'
|
||||
mock_message.metadata.collection = 'legacy_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('legacy_user', mock_message)
|
||||
|
||||
# Should use defaults since old parameters are not recognized
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -302,13 +298,12 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'precedence_user'
|
||||
mock_message.metadata.collection = 'precedence_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('precedence_user', mock_message)
|
||||
|
||||
# Should use new parameters, not old ones
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -354,13 +349,12 @@ class TestMultipleHostsHandling:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'single_user'
|
||||
mock_message.metadata.collection = 'single_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('single_user', mock_message)
|
||||
|
||||
# Single host should be converted to list
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Create test message
|
||||
storage_message = Triples(
|
||||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
metadata=Metadata(collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/person1"),
|
||||
|
|
@ -178,7 +178,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Store test data for querying
|
||||
query_test_message = Triples(
|
||||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
metadata=Metadata(collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/alice"),
|
||||
|
|
@ -212,7 +212,6 @@ class TestCassandraIntegration:
|
|||
p=None, # None for wildcard
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
collection="testcol"
|
||||
)
|
||||
s_results = await query_processor.query_triples(s_query)
|
||||
|
|
@ -232,7 +231,6 @@ class TestCassandraIntegration:
|
|||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
collection="testcol"
|
||||
)
|
||||
p_results = await query_processor.query_triples(p_query)
|
||||
|
|
@ -259,7 +257,7 @@ class TestCassandraIntegration:
|
|||
# Create multiple coroutines for concurrent storage
|
||||
async def store_person_data(person_id, name, age, department):
|
||||
message = Triples(
|
||||
metadata=Metadata(user="concurrent_test", collection="people"),
|
||||
metadata=Metadata(collection="people"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
|
||||
|
|
@ -329,7 +327,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Create a knowledge graph about a company
|
||||
company_graph = Triples(
|
||||
metadata=Metadata(user="integration_test", collection="company"),
|
||||
metadata=Metadata(collection="company"),
|
||||
triples=[
|
||||
# People and their types
|
||||
Triple(
|
||||
|
|
|
|||
|
|
@ -99,7 +99,6 @@ class TestDocumentRagIntegration:
|
|||
# Act
|
||||
result = await document_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
|
@ -110,7 +109,6 @@ class TestDocumentRagIntegration:
|
|||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||
limit=doc_limit,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
|
|
@ -278,14 +276,12 @@ class TestDocumentRagIntegration:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
f"query from {user} in {collection}",
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -353,6 +349,5 @@ class TestDocumentRagIntegration:
|
|||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == "trustgraph"
|
||||
assert call_args.kwargs['collection'] == "default"
|
||||
assert call_args.kwargs['limit'] == 20
|
||||
|
|
|
|||
|
|
@ -107,7 +107,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10,
|
||||
streaming=True,
|
||||
|
|
@ -141,7 +140,6 @@ class TestDocumentRagStreaming:
|
|||
# Act - Non-streaming
|
||||
non_streaming_result = await document_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
streaming=False
|
||||
|
|
@ -155,7 +153,6 @@ class TestDocumentRagStreaming:
|
|||
|
||||
streaming_result = await document_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
streaming=True,
|
||||
|
|
@ -178,7 +175,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5,
|
||||
streaming=True,
|
||||
|
|
@ -200,7 +196,6 @@ class TestDocumentRagStreaming:
|
|||
# Arrange & Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5,
|
||||
streaming=True,
|
||||
|
|
@ -223,7 +218,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10,
|
||||
streaming=True,
|
||||
|
|
@ -247,7 +241,6 @@ class TestDocumentRagStreaming:
|
|||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5,
|
||||
streaming=True,
|
||||
|
|
@ -272,7 +265,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=limit,
|
||||
streaming=True,
|
||||
|
|
@ -300,7 +292,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=10,
|
||||
streaming=True,
|
||||
|
|
@ -309,5 +300,4 @@ class TestDocumentRagStreaming:
|
|||
|
||||
# Assert - Verify user/collection were passed to document embeddings client
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
|
|
|||
|
|
@ -146,7 +146,6 @@ class TestGraphRagIntegration:
|
|||
# Act
|
||||
response = await graph_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
|
|
@ -163,7 +162,6 @@ class TestGraphRagIntegration:
|
|||
call_args = mock_graph_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert call_args.kwargs['limit'] == entity_limit
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
# 3. Should query triples to build knowledge subgraph
|
||||
|
|
@ -204,7 +202,6 @@ class TestGraphRagIntegration:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=config["entity_limit"],
|
||||
triple_limit=config["triple_limit"]
|
||||
|
|
@ -224,7 +221,6 @@ class TestGraphRagIntegration:
|
|||
with pytest.raises(Exception) as exc_info:
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -247,7 +243,6 @@ class TestGraphRagIntegration:
|
|||
# Act
|
||||
response = await graph_rag.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
|
@ -267,7 +262,6 @@ class TestGraphRagIntegration:
|
|||
# First query
|
||||
await graph_rag.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -277,7 +271,6 @@ class TestGraphRagIntegration:
|
|||
# Second identical query
|
||||
await graph_rag.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -289,26 +282,27 @@ class TestGraphRagIntegration:
|
|||
assert second_call_count >= 0 # Should complete without errors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client):
|
||||
"""Test that different users/collections are properly isolated"""
|
||||
async def test_graph_rag_multi_collection_isolation(self, graph_rag, mock_graph_embeddings_client):
|
||||
"""Test that different collections propagate through to the embeddings query.
|
||||
|
||||
Workspace isolation is enforced by flow.workspace at the service
|
||||
boundary — not by parameters on GraphRag.query — so this test
|
||||
verifies collection routing only.
|
||||
"""
|
||||
# Arrange
|
||||
query = "test query"
|
||||
user1, collection1 = "user1", "collection1"
|
||||
user2, collection2 = "user2", "collection2"
|
||||
collection1 = "collection1"
|
||||
collection2 = "collection2"
|
||||
|
||||
# Act
|
||||
await graph_rag.query(query=query, user=user1, collection=collection1)
|
||||
await graph_rag.query(query=query, user=user2, collection=collection2)
|
||||
await graph_rag.query(query=query, collection=collection1)
|
||||
await graph_rag.query(query=query, collection=collection2)
|
||||
|
||||
# Assert - Both users should have separate queries
|
||||
# Assert - Each call propagated its collection
|
||||
assert mock_graph_embeddings_client.query.call_count == 2
|
||||
|
||||
# Verify first call
|
||||
first_call = mock_graph_embeddings_client.query.call_args_list[0]
|
||||
assert first_call.kwargs['user'] == user1
|
||||
assert first_call.kwargs['collection'] == collection1
|
||||
|
||||
# Verify second call
|
||||
second_call = mock_graph_embeddings_client.query.call_args_list[1]
|
||||
assert second_call.kwargs['user'] == user2
|
||||
assert second_call.kwargs['collection'] == collection2
|
||||
|
|
|
|||
|
|
@ -116,7 +116,6 @@ class TestGraphRagStreaming:
|
|||
# Act - query() returns response, provenance via callback
|
||||
response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collector.collect,
|
||||
|
|
@ -154,7 +153,6 @@ class TestGraphRagStreaming:
|
|||
# Act - Non-streaming
|
||||
non_streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=False
|
||||
)
|
||||
|
|
@ -167,7 +165,6 @@ class TestGraphRagStreaming:
|
|||
|
||||
streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -189,7 +186,6 @@ class TestGraphRagStreaming:
|
|||
# Act
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -209,7 +205,6 @@ class TestGraphRagStreaming:
|
|||
# Arrange & Act
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=None # No callback provided
|
||||
|
|
@ -231,7 +226,6 @@ class TestGraphRagStreaming:
|
|||
# Act
|
||||
response = await graph_rag_streaming.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -253,7 +247,6 @@ class TestGraphRagStreaming:
|
|||
with pytest.raises(Exception) as exc_info:
|
||||
await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -273,7 +266,6 @@ class TestGraphRagStreaming:
|
|||
# Act
|
||||
await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
|
|
|
|||
|
|
@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend):
|
|||
triples_obj = Triples(
|
||||
metadata=Metadata(
|
||||
id=f"export-msg-{i}",
|
||||
user=msg_data["metadata"]["user"],
|
||||
collection=msg_data["metadata"]["collection"],
|
||||
),
|
||||
triples=to_subgraph(msg_data["triples"]),
|
||||
|
|
|
|||
|
|
@ -97,7 +97,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
return Chunk(
|
||||
metadata=Metadata(
|
||||
id="doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
|
||||
|
|
@ -247,7 +246,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -305,7 +303,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -375,7 +372,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
sample_triples = Triples(
|
||||
metadata=Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
triples=[
|
||||
|
|
@ -390,11 +386,14 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_triples
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "test_workspace"
|
||||
|
||||
# Act
|
||||
await processor.on_triples(mock_msg, None, None)
|
||||
await processor.on_triples(mock_msg, None, mock_flow)
|
||||
|
||||
# Assert
|
||||
mock_cassandra_store.add_triples.assert_called_once_with(sample_triples)
|
||||
mock_cassandra_store.add_triples.assert_called_once_with("test_workspace", sample_triples)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store):
|
||||
|
|
@ -407,7 +406,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
sample_embeddings = GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
entities=[
|
||||
|
|
@ -421,11 +419,14 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_embeddings
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "test_workspace"
|
||||
|
||||
# Act
|
||||
await processor.on_graph_embeddings(mock_msg, None, None)
|
||||
await processor.on_graph_embeddings(mock_msg, None, mock_flow)
|
||||
|
||||
# Assert
|
||||
mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings)
|
||||
mock_cassandra_store.add_graph_embeddings.assert_called_once_with("test_workspace", sample_embeddings)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor,
|
||||
|
|
@ -553,7 +554,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
)
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=Metadata(id="test", user="user", collection="collection"),
|
||||
metadata=Metadata(id="test", collection="collection"),
|
||||
chunk=b"Test chunk"
|
||||
)
|
||||
|
||||
|
|
@ -580,7 +581,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
large_chunk_batch = [
|
||||
Chunk(
|
||||
metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"),
|
||||
metadata=Metadata(id=f"doc-{i}", collection="collection"),
|
||||
chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
|
||||
)
|
||||
for i in range(100) # Large batch
|
||||
|
|
@ -617,7 +618,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
original_metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -646,9 +646,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
entity_contexts_call = entity_contexts_producer.send.call_args[0][0]
|
||||
|
||||
assert triples_call.metadata.id == "test-doc-123"
|
||||
assert triples_call.metadata.user == "test_user"
|
||||
assert triples_call.metadata.collection == "test_collection"
|
||||
|
||||
assert entity_contexts_call.metadata.id == "test-doc-123"
|
||||
assert entity_contexts_call.metadata.user == "test_user"
|
||||
assert entity_contexts_call.metadata.collection == "test_collection"
|
||||
|
|
@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Set up schemas
|
||||
proc.schemas = sample_schemas
|
||||
proc.schemas = {"default": dict(sample_schemas)}
|
||||
|
||||
# Mock the client method
|
||||
proc.client = MagicMock()
|
||||
|
|
@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration:
|
|||
}
|
||||
|
||||
# Act - Update configuration
|
||||
await integration_processor.on_schema_config(new_schema_config, "v2")
|
||||
await integration_processor.on_schema_config("default", new_schema_config, "v2")
|
||||
|
||||
# Arrange - Test query using new schema
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
|
|
@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration:
|
|||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
||||
# Assert
|
||||
assert "inventory" in integration_processor.schemas
|
||||
assert "inventory" in integration_processor.schemas["default"]
|
||||
response_call = flow_response.send.call_args
|
||||
response = response_call[0][0]
|
||||
assert response.detected_schemas == ["inventory"]
|
||||
|
|
@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration:
|
|||
graphql_generation_template="custom-graphql-generator"
|
||||
)
|
||||
|
||||
custom_processor.schemas = sample_schemas
|
||||
custom_processor.schemas = {"default": dict(sample_schemas)}
|
||||
custom_processor.client = MagicMock()
|
||||
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
|
|
@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration:
|
|||
] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
|
||||
)
|
||||
|
||||
integration_processor.schemas.update(large_schema_set)
|
||||
integration_processor.schemas["default"].update(large_schema_set)
|
||||
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
question="Show me data from table_05 and table_12",
|
||||
|
|
@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration:
|
|||
msg.properties.return_value = {"id": f"concurrent-test-{i}"}
|
||||
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
|
|||
|
|
@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
return context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Assert
|
||||
assert len(processor.schemas) == 2
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" in processor.schemas
|
||||
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 2
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" in ws_schemas
|
||||
|
||||
# Verify customer schema
|
||||
customer_schema = processor.schemas["customer_records"]
|
||||
customer_schema = ws_schemas["customer_records"]
|
||||
assert customer_schema.name == "customer_records"
|
||||
assert len(customer_schema.fields) == 4
|
||||
|
||||
|
||||
# Verify product schema
|
||||
product_schema = processor.schemas["product_catalog"]
|
||||
product_schema = ws_schemas["product_catalog"]
|
||||
assert product_schema.name == "product_catalog"
|
||||
assert len(product_schema.fields) == 4
|
||||
|
||||
|
|
@ -237,12 +239,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create realistic customer data chunk
|
||||
metadata = Metadata(
|
||||
id="customer-doc-001",
|
||||
user="integration_test",
|
||||
collection="test_documents",
|
||||
)
|
||||
|
||||
|
|
@ -304,12 +305,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create realistic product data chunk
|
||||
metadata = Metadata(
|
||||
id="product-doc-001",
|
||||
user="integration_test",
|
||||
collection="test_documents",
|
||||
)
|
||||
|
||||
|
|
@ -368,7 +368,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create multiple test chunks
|
||||
chunks_data = [
|
||||
|
|
@ -382,7 +382,6 @@ class TestObjectExtractionServiceIntegration:
|
|||
for chunk_id, text in chunks_data:
|
||||
metadata = Metadata(
|
||||
id=chunk_id,
|
||||
user="concurrent_test",
|
||||
collection="test_collection",
|
||||
)
|
||||
chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8'))
|
||||
|
|
@ -431,19 +430,21 @@ class TestObjectExtractionServiceIntegration:
|
|||
"customer_records": integration_config["schema"]["customer_records"]
|
||||
}
|
||||
}
|
||||
await processor.on_schema_config(initial_config, version=1)
|
||||
|
||||
assert len(processor.schemas) == 1
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" not in processor.schemas
|
||||
|
||||
await processor.on_schema_config("default", initial_config, version=1)
|
||||
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 1
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" not in ws_schemas
|
||||
|
||||
# Act - Reload with full configuration
|
||||
await processor.on_schema_config(integration_config, version=2)
|
||||
|
||||
await processor.on_schema_config("default", integration_config, version=2)
|
||||
|
||||
# Assert
|
||||
assert len(processor.schemas) == 2
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" in processor.schemas
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 2
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" in ws_schemas
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_resilience_integration(self, integration_config):
|
||||
|
|
@ -474,13 +475,14 @@ class TestObjectExtractionServiceIntegration:
|
|||
return AsyncMock()
|
||||
|
||||
failing_flow.side_effect = failing_context_router
|
||||
failing_flow.workspace = "default"
|
||||
processor.flow = failing_flow
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create test chunk
|
||||
metadata = Metadata(id="error-test", user="test", collection="test")
|
||||
metadata = Metadata(id="error-test", collection="test")
|
||||
chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
|
|
@ -510,12 +512,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create chunk with rich metadata
|
||||
original_metadata = Metadata(
|
||||
id="metadata-test-chunk",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -544,6 +545,5 @@ class TestObjectExtractionServiceIntegration:
|
|||
assert extracted_obj is not None
|
||||
|
||||
# Verify metadata propagation
|
||||
assert extracted_obj.metadata.user == "test_user"
|
||||
assert extracted_obj.metadata.collection == "test_collection"
|
||||
assert "metadata-test-chunk" in extracted_obj.metadata.id # Should include source reference
|
||||
|
|
@ -87,6 +87,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
return context
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -109,7 +110,7 @@ class TestPromptStreaming:
|
|||
def prompt_processor_streaming(self, mock_prompt_manager):
|
||||
"""Create Prompt processor with streaming support"""
|
||||
processor = MagicMock()
|
||||
processor.manager = mock_prompt_manager
|
||||
processor.managers = {"default": mock_prompt_manager}
|
||||
processor.config_key = "prompt"
|
||||
|
||||
# Bind the actual on_request method
|
||||
|
|
@ -248,6 +249,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
|
||||
request = PromptRequest(
|
||||
id="test_prompt",
|
||||
|
|
@ -341,6 +343,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
|
||||
request = PromptRequest(
|
||||
id="test_prompt",
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -108,7 +107,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -137,7 +135,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -162,7 +159,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -188,7 +184,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -267,7 +262,6 @@ class TestDocumentRagStreamingProtocol:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -290,7 +284,6 @@ class TestDocumentRagStreamingProtocol:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -314,7 +307,6 @@ class TestDocumentRagStreamingProtocol:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
|
|||
|
|
@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
|
|||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
|
||||
|
||||
class _MockFlowDefault:
|
||||
"""Mock Flow with default workspace for testing."""
|
||||
workspace = "default"
|
||||
name = "default"
|
||||
id = "test-processor"
|
||||
|
||||
|
||||
mock_flow_default = _MockFlowDefault()
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestRowsCassandraIntegration:
|
||||
"""Integration tests for Cassandra row storage with unified table"""
|
||||
|
|
@ -125,14 +136,13 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert "customer_records" in processor.schemas
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
assert "customer_records" in processor.schemas["default"]
|
||||
|
||||
# Step 2: Process an ExtractedObject
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="doc-001",
|
||||
user="test_user",
|
||||
collection="import_2024",
|
||||
),
|
||||
schema_name="customer_records",
|
||||
|
|
@ -149,7 +159,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify Cassandra interactions
|
||||
assert mock_cluster.connect.called
|
||||
|
|
@ -158,7 +168,7 @@ class TestRowsCassandraIntegration:
|
|||
keyspace_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE KEYSPACE" in str(call)]
|
||||
assert len(keyspace_calls) == 1
|
||||
assert "test_user" in str(keyspace_calls[0])
|
||||
assert "default" in str(keyspace_calls[0])
|
||||
|
||||
# Verify unified table creation (rows table, not per-schema table)
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -209,12 +219,12 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert len(processor.schemas) == 2
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
assert len(processor.schemas["default"]) == 2
|
||||
|
||||
# Process objects for different schemas
|
||||
product_obj = ExtractedObject(
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog"),
|
||||
metadata=Metadata(id="p1", collection="catalog"),
|
||||
schema_name="products",
|
||||
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -222,7 +232,7 @@ class TestRowsCassandraIntegration:
|
|||
)
|
||||
|
||||
order_obj = ExtractedObject(
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales"),
|
||||
metadata=Metadata(id="o1", collection="sales"),
|
||||
schema_name="orders",
|
||||
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
|
||||
confidence=0.85,
|
||||
|
|
@ -233,7 +243,7 @@ class TestRowsCassandraIntegration:
|
|||
for obj in [product_obj, order_obj]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = obj
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# All data goes into the same unified rows table
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -256,18 +266,20 @@ class TestRowsCassandraIntegration:
|
|||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Schema with multiple indexed fields
|
||||
processor.schemas["indexed_data"] = RowSchema(
|
||||
name="indexed_data",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True),
|
||||
Field(name="status", type="string", size=50, indexed=True),
|
||||
Field(name="description", type="string", size=200) # Not indexed
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"indexed_data": RowSchema(
|
||||
name="indexed_data",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True),
|
||||
Field(name="status", type="string", size=50, indexed=True),
|
||||
Field(name="description", type="string", size=200) # Not indexed
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
metadata=Metadata(id="t1", collection="test"),
|
||||
schema_name="indexed_data",
|
||||
values=[{
|
||||
"id": "123",
|
||||
|
|
@ -282,7 +294,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 data inserts (one per indexed field: id, category, status)
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -342,13 +354,12 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
|
||||
# Process batch object with multiple values
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_import",
|
||||
),
|
||||
schema_name="batch_customers",
|
||||
|
|
@ -376,7 +387,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify unified table creation
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -396,14 +407,16 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["empty_test"] = RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"empty_test": RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
}
|
||||
|
||||
# Process empty batch object
|
||||
empty_obj = ExtractedObject(
|
||||
metadata=Metadata(id="empty-1", user="test", collection="empty"),
|
||||
metadata=Metadata(id="empty-1", collection="empty"),
|
||||
schema_name="empty_test",
|
||||
values=[], # Empty batch
|
||||
confidence=1.0,
|
||||
|
|
@ -413,7 +426,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = empty_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should not create any data insert statements for empty batch
|
||||
# (partition registration may still happen)
|
||||
|
|
@ -428,17 +441,19 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["map_test"] = RowSchema(
|
||||
name="map_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="count", type="integer", size=0)
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"map_test": RowSchema(
|
||||
name="map_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="count", type="integer", size=0)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
metadata=Metadata(id="t1", collection="test"),
|
||||
schema_name="map_test",
|
||||
values=[{"id": "123", "name": "Test Item", "count": "42"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -448,7 +463,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify insert uses map for data
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -473,16 +488,18 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["partition_test"] = RowSchema(
|
||||
name="partition_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True)
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"partition_test": RowSchema(
|
||||
name="partition_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="my_collection"),
|
||||
metadata=Metadata(id="t1", collection="my_collection"),
|
||||
schema_name="partition_test",
|
||||
values=[{"id": "123", "category": "test"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -492,7 +509,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify partition registration
|
||||
partition_inserts = [call for call in mock_session.execute.call_args_list
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
|
||||
"""Test schema configuration loading and GraphQL schema generation"""
|
||||
# Load schema configuration
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Verify schemas were loaded
|
||||
assert len(processor.schemas) == 2
|
||||
|
|
@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
|
||||
"""Test Cassandra connection and dynamic table creation"""
|
||||
# Load schema configuration
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Connect to Cassandra
|
||||
processor.connect_cassandra()
|
||||
|
|
@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
|
||||
"""Test inserting data and querying via GraphQL"""
|
||||
# Load schema and connect
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Setup test data
|
||||
|
|
@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_graphql_query_with_filters(self, processor, sample_schema_config):
|
||||
"""Test GraphQL queries with filtering on indexed fields"""
|
||||
# Setup (reuse previous setup)
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
keyspace = "test_user"
|
||||
|
|
@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_graphql_error_handling(self, processor, sample_schema_config):
|
||||
"""Test GraphQL error handling for invalid queries"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Test invalid field query
|
||||
invalid_query = '''
|
||||
|
|
@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_message_processing_integration(self, processor, sample_schema_config):
|
||||
"""Test full message processing workflow"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Create mock message
|
||||
|
|
@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_concurrent_queries(self, processor, sample_schema_config):
|
||||
"""Test handling multiple concurrent GraphQL queries"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Create multiple query tasks
|
||||
|
|
@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(initial_config, version=1)
|
||||
await processor.on_schema_config("default", initial_config, version=1)
|
||||
assert len(processor.schemas) == 1
|
||||
assert "simple" in processor.schemas
|
||||
|
||||
|
|
@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(updated_config, version=2)
|
||||
await processor.on_schema_config("default", updated_config, version=2)
|
||||
|
||||
# Verify updated schemas
|
||||
assert len(processor.schemas) == 2
|
||||
|
|
@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_large_result_set_handling(self, processor, sample_schema_config):
|
||||
"""Test handling of large query result sets"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
keyspace = "large_test_user"
|
||||
|
|
@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
await processor.on_schema_config("default", schema_config, version=1)
|
||||
|
||||
# Measure query execution time
|
||||
start_time = time.time()
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ class TestStructuredQueryServiceIntegration:
|
|||
# Arrange - Create realistic query request
|
||||
request = StructuredQueryRequest(
|
||||
question="Show me all customers from California who have made purchases over $500",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
|
||||
|
|
@ -126,7 +125,6 @@ class TestStructuredQueryServiceIntegration:
|
|||
assert "orders" in objects_call_args.query
|
||||
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
|
||||
assert objects_call_args.variables["state"] == "California"
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
assert objects_call_args.collection == "default"
|
||||
|
||||
# Verify response
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to call think and observe callbacks
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = AgentRequest(
|
||||
question="What is 2 + 2?",
|
||||
user="trustgraph",
|
||||
streaming=False # Non-streaming mode
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
|
@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to return Final directly
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = AgentRequest(
|
||||
question="What is 2 + 2?",
|
||||
user="trustgraph",
|
||||
streaming=False # Non-streaming mode
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
|
@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep
|
|||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(question="Test question", user="testuser",
|
||||
def _make_request(question="Test question",
|
||||
collection="default", streaming=False,
|
||||
session_id="parent-session", task_type="research",
|
||||
framing="test framing", conversation_id="conv-1"):
|
||||
return AgentRequest(
|
||||
question=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=streaming,
|
||||
session_id=session_id,
|
||||
|
|
@ -127,7 +126,6 @@ class TestBuildSynthesisRequest:
|
|||
req = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
|
|
@ -148,7 +146,7 @@ class TestBuildSynthesisRequest:
|
|||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
req = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# Last history step should be the synthesis step
|
||||
|
|
@ -168,7 +166,7 @@ class TestBuildSynthesisRequest:
|
|||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
|
||||
agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# Entry should be removed
|
||||
|
|
@ -178,7 +176,7 @@ class TestBuildSynthesisRequest:
|
|||
agg = Aggregator()
|
||||
with pytest.raises(RuntimeError, match="No results"):
|
||||
agg.build_synthesis_request(
|
||||
"unknown", "question", "user", "default",
|
||||
"unknown", "question", "default",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator
|
|||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
|
|
@ -130,7 +129,6 @@ class TestAggregatorIntegration:
|
|||
synth = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
|
|
@ -160,7 +158,7 @@ class TestAggregatorIntegration:
|
|||
agg.record_completion("corr-1", "goal", "answer")
|
||||
|
||||
synth = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# correlation_id must be empty so it's not intercepted
|
||||
|
|
|
|||
|
|
@ -126,7 +126,6 @@ def make_base_request(**kwargs):
|
|||
state="",
|
||||
group=[],
|
||||
history=[],
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id="test-session-123",
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class MockProcessor:
|
|||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
|
|
|
|||
|
|
@ -167,39 +167,28 @@ class TestToolServiceRequest:
|
|||
"""Test cases for tool service request format"""
|
||||
|
||||
def test_request_format(self):
|
||||
"""Test that request is properly formatted with user, config, and arguments"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
"""Test that request is properly formatted with config and arguments"""
|
||||
config_values = {"style": "pun", "collection": "jokes"}
|
||||
arguments = {"topic": "programming"}
|
||||
|
||||
# Act - simulate request building
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values),
|
||||
"arguments": json.dumps(arguments)
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["user"] == "alice"
|
||||
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
|
||||
assert json.loads(request["arguments"]) == {"topic": "programming"}
|
||||
|
||||
def test_request_with_empty_config(self):
|
||||
"""Test request when no config values are provided"""
|
||||
# Arrange
|
||||
user = "bob"
|
||||
config_values = {}
|
||||
arguments = {"query": "test"}
|
||||
|
||||
# Act
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values) if config_values else "{}",
|
||||
"arguments": json.dumps(arguments) if arguments else "{}"
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["config"] == "{}"
|
||||
assert json.loads(request["arguments"]) == {"query": "test"}
|
||||
|
||||
|
|
@ -386,18 +375,13 @@ class TestJokeServiceLogic:
|
|||
assert map_topic_to_category("random topic") == "default"
|
||||
assert map_topic_to_category("") == "default"
|
||||
|
||||
def test_joke_response_personalization(self):
|
||||
"""Test that joke responses include user personalization"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
def test_joke_response_format(self):
|
||||
"""Test that joke response is formatted as expected"""
|
||||
style = "pun"
|
||||
joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
|
||||
|
||||
# Act
|
||||
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
|
||||
response = f"Here's a {style} for you:\n\n{joke}"
|
||||
|
||||
# Assert
|
||||
assert "Hey alice!" in response
|
||||
assert "pun" in response
|
||||
assert joke in response
|
||||
|
||||
|
|
@ -439,20 +423,14 @@ class TestDynamicToolServiceBase:
|
|||
|
||||
def test_request_parsing(self):
|
||||
"""Test parsing of incoming request"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
"user": "alice",
|
||||
"config": '{"style": "pun"}',
|
||||
"arguments": '{"topic": "programming"}'
|
||||
}
|
||||
|
||||
# Act
|
||||
user = request_data.get("user", "trustgraph")
|
||||
config = json.loads(request_data["config"]) if request_data["config"] else {}
|
||||
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
|
||||
|
||||
# Assert
|
||||
assert user == "alice"
|
||||
assert config == {"style": "pun"}
|
||||
assert arguments == {"topic": "programming"}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Tests for tool service lifecycle, invoke contract, streaming responses,
|
||||
multi-tenancy, and error propagation.
|
||||
and error propagation.
|
||||
|
||||
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
|
||||
classes rather than plain dicts.
|
||||
|
|
@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await svc.invoke("user", {}, {})
|
||||
await svc.invoke({}, {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_calls_invoke_with_parsed_args(self):
|
||||
|
|
@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract:
|
|||
|
||||
calls = []
|
||||
|
||||
async def tracking_invoke(user, config, arguments):
|
||||
calls.append({"user": user, "config": config, "arguments": arguments})
|
||||
async def tracking_invoke(config, arguments):
|
||||
calls.append({"config": config, "arguments": arguments})
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking_invoke
|
||||
|
|
@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract:
|
|||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user="alice",
|
||||
config='{"style": "pun"}',
|
||||
arguments='{"topic": "cats"}',
|
||||
)
|
||||
|
|
@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract:
|
|||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["user"] == "alice"
|
||||
assert calls[0]["config"] == {"style": "pun"}
|
||||
assert calls[0]["arguments"] == {"topic": "cats"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_empty_user_defaults_to_trustgraph(self):
|
||||
"""Empty user field should default to 'trustgraph'."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
received_user = None
|
||||
|
||||
async def capture_invoke(user, config, arguments):
|
||||
nonlocal received_user
|
||||
received_user = user
|
||||
return "ok"
|
||||
|
||||
svc.invoke = capture_invoke
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
|
||||
msg.properties.return_value = {"id": "req-2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert received_user == "trustgraph"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_string_response_sent_directly(self):
|
||||
"""String return from invoke → response field is the string."""
|
||||
|
|
@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def string_invoke(user, config, arguments):
|
||||
async def string_invoke(config, arguments):
|
||||
return "hello world"
|
||||
|
||||
svc.invoke = string_invoke
|
||||
|
|
@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r1"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def dict_invoke(user, config, arguments):
|
||||
async def dict_invoke(config, arguments):
|
||||
return {"result": 42}
|
||||
|
||||
svc.invoke = dict_invoke
|
||||
|
|
@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def failing_invoke(user, config, arguments):
|
||||
async def failing_invoke(config, arguments):
|
||||
raise ValueError("bad input")
|
||||
|
||||
svc.invoke = failing_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r3"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -188,13 +157,13 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def rate_limited_invoke(user, config, arguments):
|
||||
async def rate_limited_invoke(config, arguments):
|
||||
raise TooManyRequests("rate limited")
|
||||
|
||||
svc.invoke = rate_limited_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r4"}
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
|
|
@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def ok_invoke(user, config, arguments):
|
||||
async def ok_invoke(config, arguments):
|
||||
return "ok"
|
||||
|
||||
svc.invoke = ok_invoke
|
||||
|
|
@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "unique-42"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -241,7 +210,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return "tool result"
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -260,6 +229,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
|
||||
|
|
@ -280,7 +250,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return {"data": [1, 2, 3]}
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -298,6 +268,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -317,7 +288,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def failing_invoke(name, params):
|
||||
async def failing_invoke(workspace, name, params):
|
||||
raise RuntimeError("tool broke")
|
||||
|
||||
svc.invoke_tool = failing_invoke
|
||||
|
|
@ -330,6 +301,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -350,7 +322,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def rate_limited(name, params):
|
||||
async def rate_limited(workspace, name, params):
|
||||
raise TooManyRequests("slow down")
|
||||
|
||||
svc.invoke_tool = rate_limited
|
||||
|
|
@ -362,6 +334,7 @@ class TestToolServiceOnRequest:
|
|||
flow = MagicMock()
|
||||
flow.producer = {"response": AsyncMock()}
|
||||
flow.name = "test-flow"
|
||||
flow.workspace = "default"
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await svc.on_request(msg, MagicMock(), flow)
|
||||
|
|
@ -376,7 +349,8 @@ class TestToolServiceOnRequest:
|
|||
|
||||
received = {}
|
||||
|
||||
async def capture_invoke(name, params):
|
||||
async def capture_invoke(workspace, name, params):
|
||||
received["workspace"] = workspace
|
||||
received["name"] = name
|
||||
received["params"] = params
|
||||
return "ok"
|
||||
|
|
@ -390,6 +364,7 @@ class TestToolServiceOnRequest:
|
|||
flow = lambda name: mock_pub
|
||||
flow.producer = {"response": mock_pub}
|
||||
flow.name = "f"
|
||||
flow.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(
|
||||
|
|
@ -421,7 +396,6 @@ class TestToolServiceClientCall:
|
|||
))
|
||||
|
||||
result = await client.call(
|
||||
user="alice",
|
||||
config={"style": "pun"},
|
||||
arguments={"topic": "cats"},
|
||||
)
|
||||
|
|
@ -430,7 +404,6 @@ class TestToolServiceClientCall:
|
|||
|
||||
req = client.request.call_args[0][0]
|
||||
assert isinstance(req, ToolServiceRequest)
|
||||
assert req.user == "alice"
|
||||
assert json.loads(req.config) == {"style": "pun"}
|
||||
assert json.loads(req.arguments) == {"topic": "cats"}
|
||||
|
||||
|
|
@ -446,7 +419,7 @@ class TestToolServiceClientCall:
|
|||
))
|
||||
|
||||
with pytest.raises(RuntimeError, match="service down"):
|
||||
await client.call(user="u", config={}, arguments={})
|
||||
await client.call(config={}, arguments={})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_empty_config_sends_empty_json(self):
|
||||
|
|
@ -458,7 +431,7 @@ class TestToolServiceClientCall:
|
|||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config=None, arguments=None)
|
||||
await client.call(config=None, arguments=None)
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.config == "{}"
|
||||
|
|
@ -474,7 +447,7 @@ class TestToolServiceClientCall:
|
|||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config={}, arguments={}, timeout=30)
|
||||
await client.call(config={}, arguments={}, timeout=30)
|
||||
|
||||
_, kwargs = client.request.call_args
|
||||
assert kwargs["timeout"] == 30
|
||||
|
|
@ -509,7 +482,7 @@ class TestToolServiceClientStreaming:
|
|||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
assert result == "chunk1chunk2"
|
||||
|
|
@ -534,7 +507,7 @@ class TestToolServiceClientStreaming:
|
|||
|
||||
with pytest.raises(RuntimeError, match="stream failed"):
|
||||
await client.call_streaming(
|
||||
user="u", config={}, arguments={},
|
||||
config={}, arguments={},
|
||||
callback=AsyncMock(),
|
||||
)
|
||||
|
||||
|
|
@ -564,61 +537,9 @@ class TestToolServiceClientStreaming:
|
|||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
# Empty response is falsy, so callback shouldn't be called for it
|
||||
assert result == "data"
|
||||
assert received == ["data"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-tenancy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMultiTenancy:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_propagated_to_invoke(self):
|
||||
"""User from request should reach the invoke method."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
users_seen = []
|
||||
|
||||
async def tracking(user, config, arguments):
|
||||
users_seen.append(user)
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
for u in ["tenant-a", "tenant-b", "tenant-c"]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user=u, config="{}", arguments="{}",
|
||||
)
|
||||
msg.properties.return_value = {"id": f"req-{u}"}
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_sends_user_in_request(self):
|
||||
"""ToolServiceClient.call should include user in request."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
client.request = AsyncMock(return_value=ToolServiceResponse(
|
||||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="isolated-tenant", config={}, arguments={})
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.user == "isolated-tenant"
|
||||
|
|
|
|||
|
|
@ -1,17 +1,14 @@
|
|||
"""
|
||||
Tests for AsyncProcessor config notify pattern:
|
||||
- register_config_handler with types filtering
|
||||
- on_config_notify version comparison and type matching
|
||||
- fetch_config with short-lived client
|
||||
- fetch_and_apply_config retry logic
|
||||
- on_config_notify version comparison, type/workspace matching
|
||||
- fetch_and_apply_config retry logic over per-workspace fetches
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, Mock
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
# Patch heavy dependencies before importing AsyncProcessor
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
"""Create an AsyncProcessor with mocked dependencies."""
|
||||
|
|
@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
|
|||
assert len(processor.config_handlers) == 2
|
||||
|
||||
|
||||
def _notify_msg(version, changes):
|
||||
"""Build a Mock config-notify message with given version and changes dict."""
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=version, changes=changes)
|
||||
return msg
|
||||
|
||||
|
||||
class TestOnConfigNotify:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -77,9 +81,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=3, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(3, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -91,9 +93,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=5, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(5, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -105,9 +105,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["schema"])
|
||||
|
||||
msg = _notify_msg(2, {"schema": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -121,40 +119,36 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
# Mock fetch_config
|
||||
mock_config = {"prompt": {"key": "value"}}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={"key": "value"},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
handler.assert_called_once_with(
|
||||
"default", {"prompt": {"key": "value"}}, 2
|
||||
)
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_without_types_always_called(self, processor):
|
||||
async def test_handler_without_types_ignored_on_notify(self, processor):
|
||||
"""Handlers registered without types never fire on notifications."""
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler) # No types = all
|
||||
processor.register_config_handler(handler) # No types
|
||||
|
||||
mock_config = {"anything": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["whatever"])
|
||||
msg = _notify_msg(2, {"whatever": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
handler.assert_not_called()
|
||||
# Version still advances past the notify
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_handlers_type_filtering(self, processor):
|
||||
|
|
@ -168,156 +162,149 @@ class TestOnConfigNotify:
|
|||
processor.register_config_handler(schema_handler, types=["schema"])
|
||||
processor.register_config_handler(all_handler)
|
||||
|
||||
mock_config = {"prompt": {}}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
prompt_handler.assert_called_once()
|
||||
prompt_handler.assert_called_once_with(
|
||||
"default", {"prompt": {}}, 2
|
||||
)
|
||||
schema_handler.assert_not_called()
|
||||
all_handler.assert_called_once()
|
||||
all_handler.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_types_invokes_all(self, processor):
|
||||
"""Empty types list (startup signal) should invoke all handlers."""
|
||||
async def test_multi_workspace_notify_invokes_handler_per_ws(
|
||||
self, processor
|
||||
):
|
||||
"""Notify affecting multiple workspaces invokes handler once per workspace."""
|
||||
processor.config_version = 1
|
||||
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
mock_config = {}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=[])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
h1.assert_called_once()
|
||||
h2.assert_called_once()
|
||||
assert handler.call_count == 2
|
||||
called_workspaces = {c.args[0] for c in handler.call_args_list}
|
||||
assert called_workspaces == {"ws1", "ws2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_failure_handled(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler)
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("Connection failed")
|
||||
side_effect=RuntimeError("Connection failed"),
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
# Should not raise
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
|
||||
class TestFetchConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_returns_config_and_version(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.config = {"prompt": {"key": "val"}}
|
||||
mock_resp.version = 42
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
config, version = await processor.fetch_config()
|
||||
|
||||
assert config == {"prompt": {"key": "val"}}
|
||||
assert version == 42
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_raises_on_error_response(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = Mock(message="not found")
|
||||
mock_resp.config = {}
|
||||
mock_resp.version = 0
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="Config error"):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_stops_client_on_exception(self, processor):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = TimeoutError("timeout")
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(TimeoutError):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestFetchAndApplyConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_applies_config_to_all_handlers(self, processor):
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
async def test_applies_config_per_workspace(self, processor):
|
||||
"""Startup fetch invokes handler once per workspace affected."""
|
||||
h = AsyncMock()
|
||||
processor.register_config_handler(h, types=["prompt"])
|
||||
|
||||
mock_client = AsyncMock()
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
return {
|
||||
"ws1": {"k": "v1"},
|
||||
"ws2": {"k": "v2"},
|
||||
}, 10
|
||||
|
||||
mock_config = {"prompt": {}, "schema": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 10)
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
# On startup, all handlers are invoked regardless of type
|
||||
h1.assert_called_once_with(mock_config, 10)
|
||||
h2.assert_called_once_with(mock_config, 10)
|
||||
assert h.call_count == 2
|
||||
call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
|
||||
assert call_map["ws1"] == {"prompt": {"k": "v1"}}
|
||||
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
|
||||
assert processor.config_version == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
call_count = 0
|
||||
mock_config = {"prompt": {}}
|
||||
async def test_handler_without_types_skipped_at_startup(self, processor):
|
||||
"""Handlers registered without types fetch nothing at startup."""
|
||||
typed = AsyncMock()
|
||||
untyped = AsyncMock()
|
||||
processor.register_config_handler(typed, types=["prompt"])
|
||||
processor.register_config_handler(untyped)
|
||||
|
||||
async def mock_fetch():
|
||||
mock_client = AsyncMock()
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
return {"default": {}}, 1
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
typed.assert_called_once()
|
||||
untyped.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
h = AsyncMock()
|
||||
processor.register_config_handler(h, types=["prompt"])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise RuntimeError("not ready")
|
||||
return mock_config, 5
|
||||
return {"default": {"k": "v"}}, 5
|
||||
|
||||
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
|
||||
patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
), patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
assert call_count == 3
|
||||
assert processor.config_version == 5
|
||||
h.assert_called_once_with(
|
||||
"default", {"prompt": {"k": "v"}}, 5
|
||||
)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
result = await client.query(
|
||||
vector=vector,
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
timeout=30
|
||||
)
|
||||
|
|
@ -45,7 +44,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
assert isinstance(call_args, DocumentEmbeddingsRequest)
|
||||
assert call_args.vector == vector
|
||||
assert call_args.limit == 10
|
||||
assert call_args.user == "test_user"
|
||||
assert call_args.collection == "test_collection"
|
||||
|
||||
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||
|
|
@ -104,7 +102,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
client.request.assert_called_once()
|
||||
call_args = client.request.call_args[0][0]
|
||||
assert call_args.limit == 20 # Default limit
|
||||
assert call_args.user == "trustgraph" # Default user
|
||||
assert call_args.collection == "default" # Default collection
|
||||
|
||||
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||
|
|
|
|||
|
|
@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs():
|
|||
spec_two = MagicMock()
|
||||
processor = MagicMock(specifications=[spec_one, spec_two])
|
||||
|
||||
flow = Flow("processor-1", "flow-a", processor, {"answer": 42})
|
||||
flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
|
||||
|
||||
assert flow.id == "processor-1"
|
||||
assert flow.name == "flow-a"
|
||||
assert flow.workspace == "default"
|
||||
assert flow.producer == {}
|
||||
assert flow.consumer == {}
|
||||
assert flow.parameter == {}
|
||||
|
|
@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs():
|
|||
def test_flow_start_and_stop_visit_all_consumers():
|
||||
consumer_one = AsyncMock()
|
||||
consumer_two = AsyncMock()
|
||||
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
|
||||
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
|
||||
flow.consumer = {"one": consumer_one, "two": consumer_two}
|
||||
|
||||
asyncio.run(flow.start())
|
||||
|
|
@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers():
|
|||
|
||||
|
||||
def test_flow_call_returns_values_in_priority_order():
|
||||
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
|
||||
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
|
||||
flow.producer["shared"] = "producer-value"
|
||||
flow.consumer["consumer-only"] = "consumer-value"
|
||||
flow.consumer["shared"] = "consumer-value"
|
||||
|
|
|
|||
|
|
@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
|
|||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Act
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
await processor.start_flow("default", flow_name, flow_defn)
|
||||
|
||||
# Assert - Flow should be created with access to processor specifications
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
|
||||
|
||||
# The flow should have access to the processor's specifications
|
||||
# (The exact mechanism depends on Flow implementation)
|
||||
|
|
|
|||
|
|
@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
await processor.start_flow("default", flow_name, flow_defn)
|
||||
|
||||
assert flow_name in processor.flows
|
||||
assert ("default", flow_name) in processor.flows
|
||||
mock_flow_class.assert_called_once_with(
|
||||
'test-processor', flow_name, processor, flow_defn
|
||||
'test-processor', flow_name, "default", processor, flow_defn
|
||||
)
|
||||
mock_flow.start.assert_called_once()
|
||||
|
||||
|
|
@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
flow_name = 'test-flow'
|
||||
await processor.start_flow(flow_name, {'config': 'test-config'})
|
||||
await processor.start_flow("default", flow_name, {'config': 'test-config'})
|
||||
|
||||
await processor.stop_flow(flow_name)
|
||||
await processor.stop_flow("default", flow_name)
|
||||
|
||||
assert flow_name not in processor.flows
|
||||
assert ("default", flow_name) not in processor.flows
|
||||
mock_flow.stop.assert_called_once()
|
||||
|
||||
@with_async_processor_patches
|
||||
|
|
@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
await processor.stop_flow('non-existent-flow')
|
||||
await processor.stop_flow("default", 'non-existent-flow')
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
|
|
@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert 'test-flow' in processor.flows
|
||||
assert ("default", 'test-flow') in processor.flows
|
||||
mock_flow_class.assert_called_once_with(
|
||||
'test-processor', 'test-flow', processor,
|
||||
'test-processor', 'test-flow', "default", processor,
|
||||
{'config': 'test-config'}
|
||||
)
|
||||
mock_flow.start.assert_called_once()
|
||||
|
|
@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
|
|
@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
'other-data': 'some-value'
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
|
|
@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data1, version=1)
|
||||
await processor.on_configure_flows("default", config_data1, version=1)
|
||||
|
||||
config_data2 = {
|
||||
'processor:test-processor': {
|
||||
|
|
@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data2, version=2)
|
||||
await processor.on_configure_flows("default", config_data2, version=2)
|
||||
|
||||
assert 'flow1' not in processor.flows
|
||||
assert ("default", 'flow1') not in processor.flows
|
||||
mock_flow1.stop.assert_called_once()
|
||||
|
||||
assert 'flow2' in processor.flows
|
||||
assert ("default", 'flow2') in processor.flows
|
||||
mock_flow2.start.assert_called_once()
|
||||
|
||||
@with_async_processor_patches
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ def sample_text_document():
|
|||
"""Sample document with moderate length text."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-1",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = "The quick brown fox jumps over the lazy dog. " * 20
|
||||
|
|
@ -43,7 +42,6 @@ def long_text_document():
|
|||
"""Long document for testing multiple chunks."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-long",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
# Create a long text that will definitely be chunked
|
||||
|
|
@ -59,7 +57,6 @@ def unicode_text_document():
|
|||
"""Document with various unicode characters."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-unicode",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = """
|
||||
|
|
@ -84,7 +81,6 @@ def empty_text_document():
|
|||
"""Empty document for edge case testing."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-empty",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
return TextDocument(
|
||||
|
|
|
|||
|
|
@ -185,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content"
|
||||
|
|
|
|||
|
|
@ -185,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-456",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content for token chunking"
|
||||
|
|
|
|||
|
|
@ -109,7 +109,8 @@ class TestListConfigItems:
|
|||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
format_type='json',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_list_main_uses_defaults(self):
|
||||
|
|
@ -128,7 +129,8 @@ class TestListConfigItems:
|
|||
url='http://localhost:8088/',
|
||||
config_type='prompt',
|
||||
format_type='text',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -196,7 +198,8 @@ class TestGetConfigItem:
|
|||
config_type='prompt',
|
||||
key='template-1',
|
||||
format_type='json',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -253,7 +256,8 @@ class TestPutConfigItem:
|
|||
config_type='prompt',
|
||||
key='new-template',
|
||||
value='Custom prompt: {input}',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_put_main_with_stdin_arg(self):
|
||||
|
|
@ -278,7 +282,8 @@ class TestPutConfigItem:
|
|||
config_type='prompt',
|
||||
key='stdin-template',
|
||||
value=stdin_content,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_put_main_mutually_exclusive_args(self):
|
||||
|
|
@ -334,7 +339,8 @@ class TestDeleteConfigItem:
|
|||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
key='old-template',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ def knowledge_loader():
|
|||
return KnowledgeLoader(
|
||||
files=["test.ttl"],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc-123",
|
||||
url="http://test.example.com/",
|
||||
|
|
@ -64,7 +64,7 @@ class TestKnowledgeLoader:
|
|||
loader = KnowledgeLoader(
|
||||
files=["file1.ttl", "file2.ttl"],
|
||||
flow="my-flow",
|
||||
user="user1",
|
||||
workspace="user1",
|
||||
collection="col1",
|
||||
document_id="doc1",
|
||||
url="http://example.com/",
|
||||
|
|
@ -73,7 +73,7 @@ class TestKnowledgeLoader:
|
|||
|
||||
assert loader.files == ["file1.ttl", "file2.ttl"]
|
||||
assert loader.flow == "my-flow"
|
||||
assert loader.user == "user1"
|
||||
assert loader.workspace == "user1"
|
||||
assert loader.collection == "col1"
|
||||
assert loader.document_id == "doc1"
|
||||
assert loader.url == "http://example.com/"
|
||||
|
|
@ -126,7 +126,7 @@ ex:mary ex:knows ex:bob .
|
|||
loader = KnowledgeLoader(
|
||||
files=[f.name],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc",
|
||||
url="http://test.example.com/"
|
||||
|
|
@ -151,7 +151,7 @@ ex:mary ex:knows ex:bob .
|
|||
loader = KnowledgeLoader(
|
||||
files=[temp_turtle_file],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc",
|
||||
url="http://test.example.com/",
|
||||
|
|
@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
|
|||
# Verify Api was created with correct parameters
|
||||
mock_api_class.assert_called_once_with(
|
||||
url="http://test.example.com/",
|
||||
token="test-token"
|
||||
token="test-token",
|
||||
workspace="test-user"
|
||||
)
|
||||
|
||||
# Verify bulk client was obtained
|
||||
|
|
@ -174,7 +175,6 @@ ex:mary ex:knows ex:bob .
|
|||
call_args = mock_bulk.import_triples.call_args
|
||||
assert call_args[1]['flow'] == "test-flow"
|
||||
assert call_args[1]['metadata']['id'] == "test-doc"
|
||||
assert call_args[1]['metadata']['user'] == "test-user"
|
||||
assert call_args[1]['metadata']['collection'] == "test-collection"
|
||||
|
||||
# Verify import_entity_contexts was called
|
||||
|
|
@ -198,7 +198,7 @@ class TestCLIArgumentParsing:
|
|||
'tg-load-knowledge',
|
||||
'-i', 'doc-123',
|
||||
'-f', 'my-flow',
|
||||
'-U', 'my-user',
|
||||
'-w', 'my-user',
|
||||
'-C', 'my-collection',
|
||||
'-u', 'http://custom.example.com/',
|
||||
'-t', 'my-token',
|
||||
|
|
@ -216,7 +216,7 @@ class TestCLIArgumentParsing:
|
|||
token='my-token',
|
||||
flow='my-flow',
|
||||
files=['file1.ttl', 'file2.ttl'],
|
||||
user='my-user',
|
||||
workspace='my-user',
|
||||
collection='my-collection'
|
||||
)
|
||||
|
||||
|
|
@ -242,7 +242,7 @@ class TestCLIArgumentParsing:
|
|||
# Verify defaults were used
|
||||
call_args = mock_loader_class.call_args[1]
|
||||
assert call_args['flow'] == 'default'
|
||||
assert call_args['user'] == 'trustgraph'
|
||||
assert call_args['workspace'] == 'default'
|
||||
assert call_args['collection'] == 'default'
|
||||
assert call_args['url'] == 'http://localhost:8088/'
|
||||
assert call_args['token'] is None
|
||||
|
|
@ -287,7 +287,7 @@ class TestErrorHandling:
|
|||
loader = KnowledgeLoader(
|
||||
files=[temp_turtle_file],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc",
|
||||
url="http://test.example.com/"
|
||||
|
|
|
|||
|
|
@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
|
|||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_set_main_structured_query_no_arguments_needed(self):
|
||||
|
|
@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
|
|||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_valid_types_includes_row_embeddings_query(self):
|
||||
|
|
@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
|
|||
|
||||
show_main()
|
||||
|
||||
mock_show.assert_called_once_with(url='http://custom.com', token=None)
|
||||
mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
|
||||
|
||||
|
||||
class TestShowToolsRowEmbeddingsQuery:
|
||||
|
|
|
|||
|
|
@ -73,7 +73,6 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
# Act
|
||||
result = client.request(
|
||||
vector=vector,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
limit=10,
|
||||
timeout=300
|
||||
|
|
@ -82,7 +81,6 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
# Assert
|
||||
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||
client.call.assert_called_once_with(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
vector=vector,
|
||||
limit=10,
|
||||
|
|
@ -108,7 +106,6 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
# Assert
|
||||
assert result == ["test_chunk"]
|
||||
client.call.assert_called_once_with(
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
vector=vector,
|
||||
limit=10,
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ def _make_query(
|
|||
|
||||
query = Query(
|
||||
rag=rag,
|
||||
user="test-user",
|
||||
collection="test-collection",
|
||||
verbose=False,
|
||||
entity_limit=entity_limit,
|
||||
|
|
@ -208,7 +207,6 @@ class TestBatchTripleQueries:
|
|||
assert calls[0].kwargs["p"] is None
|
||||
assert calls[0].kwargs["o"] is None
|
||||
assert calls[0].kwargs["limit"] == 15
|
||||
assert calls[0].kwargs["user"] == "test-user"
|
||||
assert calls[0].kwargs["collection"] == "test-collection"
|
||||
assert calls[0].kwargs["batch_size"] == 20
|
||||
|
||||
|
|
|
|||
|
|
@ -28,10 +28,12 @@ def mock_flow_config():
|
|||
"""Mock flow configuration."""
|
||||
mock_config = Mock()
|
||||
mock_config.flows = {
|
||||
"test-flow": {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test-triples-queue"},
|
||||
"graph-embeddings-store": {"flow": "test-ge-queue"}
|
||||
"test-user": {
|
||||
"test-flow": {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test-triples-queue"},
|
||||
"graph-embeddings-store": {"flow": "test-ge-queue"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -43,7 +45,7 @@ def mock_flow_config():
|
|||
def mock_request():
|
||||
"""Mock knowledge load request."""
|
||||
request = Mock()
|
||||
request.user = "test-user"
|
||||
request.workspace = "test-user"
|
||||
request.id = "test-doc-id"
|
||||
request.collection = "test-collection"
|
||||
request.flow = "test-flow"
|
||||
|
|
@ -71,7 +73,6 @@ def sample_triples():
|
|||
return Triples(
|
||||
metadata=Metadata(
|
||||
id="test-doc-id",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
),
|
||||
triples=[
|
||||
|
|
@ -90,7 +91,6 @@ def sample_graph_embeddings():
|
|||
return GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id="test-doc-id",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
),
|
||||
entities=[
|
||||
|
|
@ -146,7 +146,6 @@ class TestKnowledgeManagerLoadCore:
|
|||
mock_triples_pub.send.assert_called_once()
|
||||
sent_triples = mock_triples_pub.send.call_args[0][1]
|
||||
assert sent_triples.metadata.collection == "test-collection"
|
||||
assert sent_triples.metadata.user == "test-user"
|
||||
assert sent_triples.metadata.id == "test-doc-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -185,7 +184,6 @@ class TestKnowledgeManagerLoadCore:
|
|||
mock_ge_pub.send.assert_called_once()
|
||||
sent_ge = mock_ge_pub.send.call_args[0][1]
|
||||
assert sent_ge.metadata.collection == "test-collection"
|
||||
assert sent_ge.metadata.user == "test-user"
|
||||
assert sent_ge.metadata.id == "test-doc-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -193,7 +191,7 @@ class TestKnowledgeManagerLoadCore:
|
|||
"""Test that load_kg_core falls back to 'default' when request.collection is None."""
|
||||
# Create request with None collection
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
mock_request.collection = None # Should fall back to "default"
|
||||
mock_request.flow = "test-flow"
|
||||
|
|
@ -269,7 +267,7 @@ class TestKnowledgeManagerLoadCore:
|
|||
"""Test that load_kg_core validates flow configuration before processing."""
|
||||
# Request with invalid flow
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
mock_request.collection = "test-collection"
|
||||
mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows
|
||||
|
|
@ -297,7 +295,7 @@ class TestKnowledgeManagerLoadCore:
|
|||
|
||||
# Test missing ID
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = None # Missing
|
||||
mock_request.collection = "test-collection"
|
||||
mock_request.flow = "test-flow"
|
||||
|
|
@ -323,7 +321,7 @@ class TestKnowledgeManagerOtherMethods:
|
|||
async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples):
|
||||
"""Test that get_kg_core preserves collection field from stored data."""
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
|
||||
mock_respond = AsyncMock()
|
||||
|
|
@ -354,7 +352,7 @@ class TestKnowledgeManagerOtherMethods:
|
|||
async def test_list_kg_cores(self, knowledge_manager):
|
||||
"""Test listing knowledge cores."""
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
|
||||
mock_respond = AsyncMock()
|
||||
|
||||
|
|
@ -376,7 +374,7 @@ class TestKnowledgeManagerOtherMethods:
|
|||
async def test_delete_kg_core(self, knowledge_manager):
|
||||
"""Test deleting knowledge cores."""
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
|
||||
mock_respond = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock message with inline data
|
||||
content = b"# Document Title\nBody text content."
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
mock_metadata = Metadata(id="test-doc",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
|
|
@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock message
|
||||
content = b"fake pdf"
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
mock_metadata = Metadata(id="test-doc",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
|
|
@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
]
|
||||
|
||||
content = b"fake pdf"
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
mock_metadata = Metadata(id="test-doc",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_basic(self):
|
||||
"""Test basic collection name creation"""
|
||||
result = make_safe_collection_name(
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -21,7 +21,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_special_characters(self):
|
||||
"""Test collection name creation with special characters that need sanitization"""
|
||||
result = make_safe_collection_name(
|
||||
user="user@domain.com",
|
||||
workspace="user@domain.com",
|
||||
collection="test-collection.v2",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -30,7 +30,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_unicode(self):
|
||||
"""Test collection name creation with Unicode characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="测试用户",
|
||||
workspace="测试用户",
|
||||
collection="colección_española",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -39,7 +39,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_spaces(self):
|
||||
"""Test collection name creation with spaces"""
|
||||
result = make_safe_collection_name(
|
||||
user="test user",
|
||||
workspace="test user",
|
||||
collection="my test collection",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -48,7 +48,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self):
|
||||
"""Test collection name creation with multiple consecutive special characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="user@@@domain!!!",
|
||||
workspace="user@@@domain!!!",
|
||||
collection="test---collection...v2",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -57,7 +57,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_leading_trailing_underscores(self):
|
||||
"""Test collection name creation with leading/trailing special characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="__test_user__",
|
||||
workspace="__test_user__",
|
||||
collection="@@test_collection##",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -66,7 +66,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_empty_user(self):
|
||||
"""Test collection name creation with empty user (should fallback to 'default')"""
|
||||
result = make_safe_collection_name(
|
||||
user="",
|
||||
workspace="",
|
||||
collection="test_collection",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -75,7 +75,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_empty_collection(self):
|
||||
"""Test collection name creation with empty collection (should fallback to 'default')"""
|
||||
result = make_safe_collection_name(
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -84,7 +84,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_both_empty(self):
|
||||
"""Test collection name creation with both user and collection empty"""
|
||||
result = make_safe_collection_name(
|
||||
user="",
|
||||
workspace="",
|
||||
collection="",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -93,7 +93,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_only_special_characters(self):
|
||||
"""Test collection name creation with only special characters (should fallback to 'default')"""
|
||||
result = make_safe_collection_name(
|
||||
user="@@@!!!",
|
||||
workspace="@@@!!!",
|
||||
collection="---###",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -102,7 +102,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_whitespace_only(self):
|
||||
"""Test collection name creation with whitespace-only strings"""
|
||||
result = make_safe_collection_name(
|
||||
user=" \n\t ",
|
||||
workspace=" \n\t ",
|
||||
collection=" \r\n ",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -111,7 +111,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_mixed_valid_invalid_chars(self):
|
||||
"""Test collection name creation with mixed valid and invalid characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="user123@test",
|
||||
workspace="user123@test",
|
||||
collection="coll_2023.v1",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -147,7 +147,7 @@ class TestMilvusCollectionNaming:
|
|||
long_collection = "b" * 100
|
||||
|
||||
result = make_safe_collection_name(
|
||||
user=long_user,
|
||||
workspace=long_user,
|
||||
collection=long_collection,
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -159,7 +159,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_numeric_values(self):
|
||||
"""Test collection name creation with numeric user/collection values"""
|
||||
result = make_safe_collection_name(
|
||||
user="user123",
|
||||
workspace="user123",
|
||||
collection="collection456",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -168,7 +168,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_case_sensitivity(self):
|
||||
"""Test that collection name creation preserves case"""
|
||||
result = make_safe_collection_name(
|
||||
user="TestUser",
|
||||
workspace="TestUser",
|
||||
collection="TestCollection",
|
||||
prefix="Doc"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,9 +20,8 @@ def processor():
|
|||
)
|
||||
|
||||
|
||||
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
|
||||
user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"):
|
||||
metadata = Metadata(id=doc_id, collection=collection)
|
||||
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
|
|
@ -127,7 +126,7 @@ class TestDocumentEmbeddingsProcessor:
|
|||
@pytest.mark.asyncio
|
||||
async def test_metadata_preserved(self, processor):
|
||||
"""Output should carry the original metadata."""
|
||||
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
|
||||
msg = _make_chunk_message(collection="reports", doc_id="d1")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]]
|
||||
|
|
@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor:
|
|||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "reports"
|
||||
assert result.metadata.id == "d1"
|
||||
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"):
|
|||
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
|
||||
|
||||
|
||||
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
def _make_message(entities, doc_id="doc-1", collection="default"):
|
||||
metadata = Metadata(id=doc_id, collection=collection)
|
||||
value = EntityContexts(metadata=metadata, entities=entities)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
|
|
@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing:
|
|||
_make_entity_context(f"E{i}", f"ctx {i}")
|
||||
for i in range(5)
|
||||
]
|
||||
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
|
||||
msg = _make_message(entities, doc_id="doc-42", collection="main")
|
||||
|
||||
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
|
||||
mock_output = AsyncMock()
|
||||
|
|
@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing:
|
|||
for call in mock_output.send.call_args_list:
|
||||
result = call[0][0]
|
||||
assert result.metadata.id == "doc-42"
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
await processor.on_schema_config("default", config_data, 1)
|
||||
|
||||
assert 'customers' in processor.schemas
|
||||
assert processor.schemas['customers'].name == 'customers'
|
||||
assert len(processor.schemas['customers'].fields) == 3
|
||||
assert 'customers' in processor.schemas["default"]
|
||||
assert processor.schemas["default"]['customers'].name == 'customers'
|
||||
assert len(processor.schemas["default"]['customers'].fields) == 3
|
||||
|
||||
async def test_on_schema_config_handles_missing_type(self):
|
||||
"""Test that missing schema type is handled gracefully"""
|
||||
|
|
@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
'other_type': {}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
await processor.on_schema_config("default", config_data, 1)
|
||||
|
||||
assert processor.schemas == {}
|
||||
assert processor.schemas.get("default", {}) == {}
|
||||
|
||||
async def test_on_message_drops_unknown_collection(self):
|
||||
"""Test that messages for unknown collections are dropped"""
|
||||
|
|
@ -285,7 +285,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('default', 'test_collection')] = {}
|
||||
# No schemas registered
|
||||
|
||||
metadata = MagicMock()
|
||||
|
|
@ -322,17 +322,19 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('default', 'test_collection')] = {}
|
||||
|
||||
# Set up schema
|
||||
processor.schemas['customers'] = RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
'customers': RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
|
|
@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
return MagicMock()
|
||||
|
||||
mock_flow = MagicMock(side_effect=flow_factory)
|
||||
mock_flow.workspace = "default"
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,11 +34,10 @@ def _make_defn(entity, definition):
|
|||
return {"entity": entity, "definition": definition}
|
||||
|
||||
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
|
||||
user="user-1", collection="col-1", document_id=""):
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
|
||||
chunk = Chunk(
|
||||
metadata=Metadata(
|
||||
id=meta_id, root=root, user=user, collection=collection,
|
||||
id=meta_id, root=root, collection=collection,
|
||||
),
|
||||
chunk=text.encode("utf-8"),
|
||||
document_id=document_id,
|
||||
|
|
@ -229,8 +228,7 @@ class TestMetadataPreservation:
|
|||
defs = [_make_defn("X", "def X")]
|
||||
flow, triples_pub, _, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-1", root="r-1",
|
||||
user="u-1", collection="coll-1",
|
||||
"text", meta_id="c-1", root="r-1", collection="coll-1",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
|
@ -238,7 +236,6 @@ class TestMetadataPreservation:
|
|||
for triples_msg in _sent_triples(triples_pub):
|
||||
assert triples_msg.metadata.id == "c-1"
|
||||
assert triples_msg.metadata.root == "r-1"
|
||||
assert triples_msg.metadata.user == "u-1"
|
||||
assert triples_msg.metadata.collection == "coll-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -247,8 +244,7 @@ class TestMetadataPreservation:
|
|||
defs = [_make_defn("X", "def X")]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-2", root="r-2",
|
||||
user="u-2", collection="coll-2",
|
||||
"text", meta_id="c-2", root="r-2", collection="coll-2",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
|
|
|||
|
|
@ -38,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True):
|
|||
}
|
||||
|
||||
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
|
||||
user="user-1", collection="col-1", document_id=""):
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
|
||||
"""Build a mock message wrapping a Chunk."""
|
||||
chunk = Chunk(
|
||||
metadata=Metadata(
|
||||
id=meta_id, root=root, user=user, collection=collection,
|
||||
id=meta_id, root=root, collection=collection,
|
||||
),
|
||||
chunk=text.encode("utf-8"),
|
||||
document_id=document_id,
|
||||
|
|
@ -189,8 +188,7 @@ class TestMetadataPreservation:
|
|||
rels = [_make_rel("X", "rel", "Y")]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-1", root="r-1",
|
||||
user="u-1", collection="coll-1",
|
||||
"text", meta_id="c-1", root="r-1", collection="coll-1",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
|
@ -198,7 +196,6 @@ class TestMetadataPreservation:
|
|||
for triples_msg in _sent_triples(pub):
|
||||
assert triples_msg.metadata.id == "c-1"
|
||||
assert triples_msg.metadata.root == "r-1"
|
||||
assert triples_msg.metadata.user == "u-1"
|
||||
assert triples_msg.metadata.collection == "coll-1"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
|
|||
ConfigReceiver.config_loader = Mock()
|
||||
|
||||
|
||||
def _notify(version, changes):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=version, changes=changes)
|
||||
return msg
|
||||
|
||||
|
||||
class TestConfigReceiver:
|
||||
"""Test cases for ConfigReceiver class"""
|
||||
|
||||
|
|
@ -47,98 +53,70 @@ class TestConfigReceiver:
|
|||
assert handler2 in config_receiver.flow_handlers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_new_version(self):
|
||||
"""Test on_config_notify triggers fetch for newer version"""
|
||||
async def test_on_config_notify_new_version_fetches_per_workspace(self):
|
||||
"""Notify with newer version fetches each affected workspace."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
# Mock fetch_and_apply
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with newer version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=["flow"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
assert len(fetch_calls) == 1
|
||||
msg = _notify(2, {"flow": ["ws1", "ws2"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert set(fetch_calls) == {"ws1", "ws2"}
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_old_version_ignored(self):
|
||||
"""Test on_config_notify ignores older versions"""
|
||||
"""Older-version notifies are ignored."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 5
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with older version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=3, types=["flow"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
assert len(fetch_calls) == 0
|
||||
msg = _notify(3, {"flow": ["ws1"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert fetch_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_irrelevant_types_ignored(self):
|
||||
"""Test on_config_notify ignores types the gateway doesn't care about"""
|
||||
"""Notifies without flow changes advance version but skip fetch."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with non-flow type
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
# Version should be updated but no fetch
|
||||
assert len(fetch_calls) == 0
|
||||
msg = _notify(2, {"prompt": ["ws1"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert fetch_calls == []
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_flow_type_triggers_fetch(self):
|
||||
"""Test on_config_notify fetches for flow-related types"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
for type_name in ["flow"]:
|
||||
fetch_calls.clear()
|
||||
config_receiver.config_version = 1
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=[type_name])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_exception_handling(self):
|
||||
"""Test on_config_notify handles exceptions gracefully"""
|
||||
"""on_config_notify swallows exceptions from message decode."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create notify message that causes an exception
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.side_effect = Exception("Test exception")
|
||||
|
||||
|
|
@ -146,19 +124,18 @@ class TestConfigReceiver:
|
|||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_new_flows(self):
|
||||
"""Test fetch_and_apply starts new flows"""
|
||||
async def test_fetch_and_apply_workspace_starts_new_flows(self):
|
||||
"""fetch_and_apply_workspace starts newly-configured flows."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock _create_config_client to return a mock client
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
"flow2": '{"name": "test_flow_2"}'
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -167,36 +144,39 @@ class TestConfigReceiver:
|
|||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
start_flow_calls = []
|
||||
async def mock_start_flow(id, flow):
|
||||
start_flow_calls.append((id, flow))
|
||||
|
||||
async def mock_start_flow(workspace, id, flow):
|
||||
start_flow_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert config_receiver.config_version == 5
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow1" in config_receiver.flows["default"]
|
||||
assert "flow2" in config_receiver.flows["default"]
|
||||
assert len(start_flow_calls) == 2
|
||||
assert all(c[0] == "default" for c in start_flow_calls)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_removed_flows(self):
|
||||
"""Test fetch_and_apply stops removed flows"""
|
||||
async def test_fetch_and_apply_workspace_stops_removed_flows(self):
|
||||
"""fetch_and_apply_workspace stops flows no longer configured."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
"default": {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"},
|
||||
}
|
||||
}
|
||||
|
||||
# Config now only has flow1
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}'
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -205,20 +185,22 @@ class TestConfigReceiver:
|
|||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
stop_flow_calls = []
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_flow_calls.append((id, flow))
|
||||
|
||||
async def mock_stop_flow(workspace, id, flow):
|
||||
stop_flow_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
assert "flow1" in config_receiver.flows["default"]
|
||||
assert "flow2" not in config_receiver.flows["default"]
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0][0] == "flow2"
|
||||
assert stop_flow_calls[0][:2] == ("default", "flow2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_no_flows(self):
|
||||
"""Test fetch_and_apply with empty config"""
|
||||
async def test_fetch_and_apply_workspace_with_no_flows(self):
|
||||
"""Empty workspace config clears any local flow state."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
|
@ -231,88 +213,100 @@ class TestConfigReceiver:
|
|||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.flows.get("default", {}) == {}
|
||||
assert config_receiver.config_version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handlers(self):
|
||||
"""Test start_flow method with multiple handlers"""
|
||||
"""start_flow fans out to every registered flow handler."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.start_flow = Mock()
|
||||
handler1.start_flow = AsyncMock()
|
||||
handler2 = Mock()
|
||||
handler2.start_flow = Mock()
|
||||
handler2.start_flow = AsyncMock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
await config_receiver.start_flow("default", "flow1", flow_data)
|
||||
|
||||
handler1.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler1.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
handler2.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handler_exception(self):
|
||||
"""Test start_flow method handles handler exceptions"""
|
||||
"""Handler exceptions in start_flow do not propagate."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler = Mock()
|
||||
handler.start_flow = Mock(side_effect=Exception("Handler error"))
|
||||
handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
await config_receiver.start_flow("default", "flow1", flow_data)
|
||||
|
||||
handler.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handlers(self):
|
||||
"""Test stop_flow method with multiple handlers"""
|
||||
"""stop_flow fans out to every registered flow handler."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.stop_flow = Mock()
|
||||
handler1.stop_flow = AsyncMock()
|
||||
handler2 = Mock()
|
||||
handler2.stop_flow = Mock()
|
||||
handler2.stop_flow = AsyncMock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
await config_receiver.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler1.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
handler2.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handler_exception(self):
|
||||
"""Test stop_flow method handles handler exceptions"""
|
||||
"""Handler exceptions in stop_flow do not propagate."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler = Mock()
|
||||
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
|
||||
handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
await config_receiver.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
handler.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@patch('asyncio.create_task')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -329,25 +323,25 @@ class TestConfigReceiver:
|
|||
mock_create_task.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_mixed_flow_operations(self):
|
||||
"""Test fetch_and_apply with mixed add/remove operations"""
|
||||
async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
|
||||
"""fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
"default": {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"},
|
||||
}
|
||||
}
|
||||
|
||||
# Config removes flow1, keeps flow2, adds flow3
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
"flow3": '{"name": "test_flow_3"}'
|
||||
"flow3": '{"name": "test_flow_3"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -358,20 +352,22 @@ class TestConfigReceiver:
|
|||
start_calls = []
|
||||
stop_calls = []
|
||||
|
||||
async def mock_start_flow(id, flow):
|
||||
start_calls.append((id, flow))
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_calls.append((id, flow))
|
||||
async def mock_start_flow(workspace, id, flow):
|
||||
start_calls.append((workspace, id, flow))
|
||||
|
||||
async def mock_stop_flow(workspace, id, flow):
|
||||
stop_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
ws_flows = config_receiver.flows["default"]
|
||||
assert "flow1" not in ws_flows
|
||||
assert "flow2" in ws_flows
|
||||
assert "flow3" in ws_flows
|
||||
assert len(start_calls) == 1
|
||||
assert start_calls[0][0] == "flow3"
|
||||
assert start_calls[0][:2] == ("default", "flow3")
|
||||
assert len(stop_calls) == 1
|
||||
assert stop_calls[0][0] == "flow1"
|
||||
assert stop_calls[0][:2] == ("default", "flow1")
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ def _ge_response_dict():
|
|||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"user": "alice",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"entities": [
|
||||
|
|
@ -59,7 +58,6 @@ def _triples_response_dict():
|
|||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"user": "alice",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"triples": [
|
||||
|
|
@ -73,9 +71,9 @@ def _triples_response_dict():
|
|||
}
|
||||
|
||||
|
||||
def _make_request(id_="doc-1", user="alice"):
|
||||
def _make_request(id_="doc-1", workspace="alice"):
|
||||
request = Mock()
|
||||
request.query = {"id": id_, "user": user}
|
||||
request.query = {"id": id_, "workspace": workspace}
|
||||
return request
|
||||
|
||||
|
||||
|
|
@ -149,12 +147,8 @@ class TestCoreExportWireFormat:
|
|||
msg_type, payload = items[0]
|
||||
assert msg_type == "ge"
|
||||
|
||||
# Metadata envelope: only id/user/collection — no stale `m["m"]`.
|
||||
assert payload["m"] == {
|
||||
"i": "doc-1",
|
||||
"u": "alice",
|
||||
"c": "testcoll",
|
||||
}
|
||||
# Metadata envelope: only id/collection — no stale `m["m"]`.
|
||||
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
|
||||
|
||||
# Entities: each carries the *singular* `v` and the term envelope
|
||||
assert len(payload["e"]) == 2
|
||||
|
|
@ -202,11 +196,7 @@ class TestCoreExportWireFormat:
|
|||
|
||||
msg_type, payload = items[0]
|
||||
assert msg_type == "t"
|
||||
assert payload["m"] == {
|
||||
"i": "doc-1",
|
||||
"u": "alice",
|
||||
"c": "testcoll",
|
||||
}
|
||||
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
|
||||
assert len(payload["t"]) == 1
|
||||
|
||||
|
||||
|
|
@ -240,7 +230,7 @@ class TestCoreImportWireFormat:
|
|||
payload = msgpack.packb((
|
||||
"ge",
|
||||
{
|
||||
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
|
||||
"m": {"i": "doc-1", "c": "testcoll"},
|
||||
"e": [
|
||||
{
|
||||
"e": {"t": "i", "i": "http://example.org/alice"},
|
||||
|
|
@ -266,7 +256,7 @@ class TestCoreImportWireFormat:
|
|||
|
||||
req = captured[0]
|
||||
assert req["operation"] == "put-kg-core"
|
||||
assert req["user"] == "alice"
|
||||
assert req["workspace"] == "alice"
|
||||
assert req["id"] == "doc-1"
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
|
|
@ -275,7 +265,6 @@ class TestCoreImportWireFormat:
|
|||
assert "metadata" not in ge["metadata"]
|
||||
assert ge["metadata"] == {
|
||||
"id": "doc-1",
|
||||
"user": "alice",
|
||||
"collection": "default",
|
||||
}
|
||||
|
||||
|
|
@ -302,7 +291,7 @@ class TestCoreImportWireFormat:
|
|||
payload = msgpack.packb((
|
||||
"t",
|
||||
{
|
||||
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
|
||||
"m": {"i": "doc-1", "c": "testcoll"},
|
||||
"t": [
|
||||
{
|
||||
"s": {"t": "i", "i": "http://example.org/alice"},
|
||||
|
|
@ -407,11 +396,10 @@ class TestCoreImportExportRoundTrip:
|
|||
original = _ge_response_dict()["graph-embeddings"]
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
# The import side overrides id/user from the URL query (intentional),
|
||||
# The import side overrides id from the URL query (intentional),
|
||||
# so we only round-trip the entity payload itself.
|
||||
assert ge["metadata"]["id"] == original["metadata"]["id"]
|
||||
assert ge["metadata"]["user"] == original["metadata"]["user"]
|
||||
|
||||
|
||||
assert len(ge["entities"]) == len(original["entities"])
|
||||
for got, want in zip(ge["entities"], original["entities"]):
|
||||
assert got["vector"] == want["vector"]
|
||||
|
|
|
|||
|
|
@ -72,10 +72,10 @@ class TestDispatcherManager:
|
|||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await manager.start_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" in manager.flows
|
||||
assert manager.flows["flow1"] == flow_data
|
||||
await manager.start_flow("default", "flow1", flow_data)
|
||||
|
||||
assert ("default", "flow1") in manager.flows
|
||||
assert manager.flows[("default", "flow1")] == flow_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow(self):
|
||||
|
|
@ -86,11 +86,11 @@ class TestDispatcherManager:
|
|||
|
||||
# Pre-populate with a flow
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
manager.flows["flow1"] = flow_data
|
||||
|
||||
await manager.stop_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" not in manager.flows
|
||||
manager.flows[("default", "flow1")] = flow_data
|
||||
|
||||
await manager.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
assert ("default", "flow1") not in manager.flows
|
||||
|
||||
def test_dispatch_global_service_returns_wrapper(self):
|
||||
"""Test dispatch_global_service returns DispatcherWrapper"""
|
||||
|
|
@ -275,12 +275,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
|
|
@ -290,7 +290,7 @@ class TestDispatcherManager:
|
|||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
|
||||
params = {"flow": "test_flow", "kind": "triples"}
|
||||
result = await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
|
|
@ -326,12 +326,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
|
||||
mock_dispatchers.__contains__.return_value = False
|
||||
|
||||
|
|
@ -348,12 +348,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
|
|
@ -404,7 +404,7 @@ class TestDispatcherManager:
|
|||
params = {"flow": "test_flow", "kind": "agent"}
|
||||
result = await manager.process_flow_service("data", "responder", params)
|
||||
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
|
||||
assert result == "flow_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -415,14 +415,14 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Add flow to the flows dictionary
|
||||
manager.flows["test_flow"] = {"services": {"agent": {}}}
|
||||
|
||||
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="cached_result")
|
||||
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
assert result == "cached_result"
|
||||
|
|
@ -435,7 +435,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
|
|
@ -443,7 +443,7 @@ class TestDispatcherManager:
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
|
|
@ -452,23 +452,23 @@ class TestDispatcherManager:
|
|||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
request_queue="agent_request_queue",
|
||||
response_queue="agent_response_queue",
|
||||
timeout=120,
|
||||
consumer="api-gateway-test_flow-agent-request",
|
||||
subscriber="api-gateway-test_flow-agent-request"
|
||||
consumer="api-gateway-default-test_flow-agent-request",
|
||||
subscriber="api-gateway-default-test_flow-agent-request"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
|
||||
assert result == "new_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -479,26 +479,26 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"text-load": {"flow": "text_load_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
|
||||
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
|
||||
mock_rr_dispatchers.__contains__.return_value = False
|
||||
mock_sender_dispatchers.__contains__.return_value = True
|
||||
|
||||
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="sender_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
|
||||
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
|
|
@ -506,9 +506,9 @@ class TestDispatcherManager:
|
|||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
|
||||
assert result == "sender_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -519,7 +519,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
|
||||
|
|
@ -529,14 +529,14 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow without agent interface
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"text-completion": {"request": "req", "response": "resp"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_invalid_kind(self):
|
||||
|
|
@ -546,7 +546,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow with interface but unsupported kind
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"invalid-kind": {"request": "req", "response": "resp"}
|
||||
}
|
||||
|
|
@ -558,7 +558,7 @@ class TestDispatcherManager:
|
|||
mock_sender_dispatchers.__contains__.return_value = False
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid kind"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
|
||||
|
|
@ -608,7 +608,7 @@ class TestDispatcherManager:
|
|||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
|
|
@ -630,7 +630,7 @@ class TestDispatcherManager:
|
|||
mock_rr_dispatchers.__contains__.return_value = True
|
||||
|
||||
results = await asyncio.gather(*[
|
||||
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
for _ in range(5)
|
||||
])
|
||||
|
||||
|
|
@ -638,5 +638,5 @@ class TestDispatcherManager:
|
|||
"Dispatcher class instantiated more than once — duplicate consumer bug"
|
||||
)
|
||||
assert mock_dispatcher.start.call_count == 1
|
||||
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
|
||||
assert all(r == "result" for r in results)
|
||||
|
|
@ -186,7 +186,6 @@ class TestEntityContextsImportMessageProcessing:
|
|||
assert isinstance(sent, EntityContexts)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.user == "testuser"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
|
|
|
|||
|
|
@ -188,7 +188,6 @@ class TestGraphEmbeddingsImportMessageProcessing:
|
|||
assert isinstance(sent, GraphEmbeddings)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.user == "testuser"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
|
|
|
|||
|
|
@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing:
|
|||
|
||||
# Check metadata
|
||||
assert sent_object.metadata.id == "obj-123"
|
||||
assert sent_object.metadata.user == "testuser"
|
||||
assert sent_object.metadata.collection == "testcollection"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ class TestTextDocumentTranslator:
|
|||
)
|
||||
|
||||
assert msg.metadata.id == "doc-1"
|
||||
assert msg.metadata.user == "alice"
|
||||
assert msg.metadata.collection == "research"
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
||||
|
|
|
|||
|
|
@ -29,10 +29,9 @@ class Triple:
|
|||
self.o = o
|
||||
|
||||
class Metadata:
|
||||
def __init__(self, id, user, collection, root=""):
|
||||
def __init__(self, id, collection, root=""):
|
||||
self.id = id
|
||||
self.root = root
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
|
||||
class Triples:
|
||||
|
|
@ -108,7 +107,6 @@ def sample_triples(sample_triple):
|
|||
"""Sample Triples batch object"""
|
||||
metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -123,7 +121,6 @@ def sample_chunk():
|
|||
"""Sample text chunk for processing"""
|
||||
metadata = Metadata(
|
||||
id="test-chunk-456",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -322,7 +322,6 @@ This is not JSON at all
|
|||
assert isinstance(sent_triples, Triples)
|
||||
# Check metadata fields individually since implementation creates new Metadata object
|
||||
assert sent_triples.metadata.id == sample_metadata.id
|
||||
assert sent_triples.metadata.user == sample_metadata.user
|
||||
assert sent_triples.metadata.collection == sample_metadata.collection
|
||||
assert len(sent_triples.triples) == 1
|
||||
assert sent_triples.triples[0].s.iri == "test:subject"
|
||||
|
|
@ -346,7 +345,6 @@ This is not JSON at all
|
|||
assert isinstance(sent_contexts, EntityContexts)
|
||||
# Check metadata fields individually since implementation creates new Metadata object
|
||||
assert sent_contexts.metadata.id == sample_metadata.id
|
||||
assert sent_contexts.metadata.user == sample_metadata.user
|
||||
assert sent_contexts.metadata.collection == sample_metadata.collection
|
||||
assert len(sent_contexts.entities) == 1
|
||||
assert sent_contexts.entities[0].entity.iri == "test:entity"
|
||||
|
|
|
|||
|
|
@ -311,8 +311,7 @@ class TestObjectExtractionBusinessLogic:
|
|||
"""Test ExtractedObject creation and properties"""
|
||||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-extraction-001",
|
||||
user="test_user",
|
||||
id="test-extraction-001",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic:
|
|||
assert extracted_obj.values[0]["customer_id"] == "CUST001"
|
||||
assert extracted_obj.confidence == 0.95
|
||||
assert "John Doe" in extracted_obj.source_span
|
||||
assert extracted_obj.metadata.user == "test_user"
|
||||
|
||||
def test_config_parsing_error_handling(self):
|
||||
"""Test configuration parsing with invalid JSON"""
|
||||
|
|
|
|||
|
|
@ -371,7 +371,6 @@ class TestTripleConstructionLogic:
|
|||
|
||||
metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -384,7 +383,6 @@ class TestTripleConstructionLogic:
|
|||
# Assert
|
||||
assert isinstance(triples_batch, Triples)
|
||||
assert triples_batch.metadata.id == "test-doc-123"
|
||||
assert triples_batch.metadata.user == "test_user"
|
||||
assert triples_batch.metadata.collection == "test_collection"
|
||||
assert len(triples_batch.triples) == 2
|
||||
|
||||
|
|
|
|||
|
|
@ -33,12 +33,12 @@ def _make_librarian(min_chunk_size=1):
|
|||
|
||||
|
||||
def _make_doc_metadata(
|
||||
doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
|
||||
doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc"
|
||||
):
|
||||
meta = MagicMock()
|
||||
meta.id = doc_id
|
||||
meta.kind = kind
|
||||
meta.user = user
|
||||
meta.workspace = workspace
|
||||
meta.title = title
|
||||
meta.time = 1700000000
|
||||
meta.comments = ""
|
||||
|
|
@ -47,27 +47,27 @@ def _make_doc_metadata(
|
|||
|
||||
|
||||
def _make_begin_request(
|
||||
doc_id="doc-1", kind="application/pdf", user="alice",
|
||||
doc_id="doc-1", kind="application/pdf", workspace="alice",
|
||||
total_size=10_000_000, chunk_size=0
|
||||
):
|
||||
req = MagicMock()
|
||||
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user)
|
||||
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace)
|
||||
req.total_size = total_size
|
||||
req.chunk_size = chunk_size
|
||||
return req
|
||||
|
||||
|
||||
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
|
||||
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"):
|
||||
req = MagicMock()
|
||||
req.upload_id = upload_id
|
||||
req.chunk_index = chunk_index
|
||||
req.user = user
|
||||
req.workspace = workspace
|
||||
req.content = base64.b64encode(content)
|
||||
return req
|
||||
|
||||
|
||||
def _make_session(
|
||||
user="alice", total_chunks=5, chunk_size=2_000_000,
|
||||
workspace="alice", total_chunks=5, chunk_size=2_000_000,
|
||||
total_size=10_000_000, chunks_received=None, object_id="obj-1",
|
||||
s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
|
||||
):
|
||||
|
|
@ -76,11 +76,11 @@ def _make_session(
|
|||
if document_metadata is None:
|
||||
document_metadata = json.dumps({
|
||||
"id": document_id, "kind": "application/pdf",
|
||||
"user": user, "title": "Test", "time": 1700000000,
|
||||
"workspace": workspace, "title": "Test", "time": 1700000000,
|
||||
"comments": "", "tags": [],
|
||||
})
|
||||
return {
|
||||
"user": user,
|
||||
"workspace": workspace,
|
||||
"total_chunks": total_chunks,
|
||||
"chunk_size": chunk_size,
|
||||
"total_size": total_size,
|
||||
|
|
@ -259,10 +259,10 @@ class TestUploadChunk:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = _make_upload_chunk_request(user="bob")
|
||||
req = _make_upload_chunk_request(workspace="bob")
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
|
|
@ -353,7 +353,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.complete_upload(req)
|
||||
|
||||
|
|
@ -375,7 +375,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
await lib.complete_upload(req)
|
||||
|
||||
|
|
@ -394,7 +394,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="Missing chunks"):
|
||||
await lib.complete_upload(req)
|
||||
|
|
@ -406,7 +406,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-gone"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="not found"):
|
||||
await lib.complete_upload(req)
|
||||
|
|
@ -414,12 +414,12 @@ class TestCompleteUpload:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
req.workspace = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.complete_upload(req)
|
||||
|
|
@ -439,7 +439,7 @@ class TestAbortUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.abort_upload(req)
|
||||
|
||||
|
|
@ -456,7 +456,7 @@ class TestAbortUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-gone"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="not found"):
|
||||
await lib.abort_upload(req)
|
||||
|
|
@ -464,12 +464,12 @@ class TestAbortUpload:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
req.workspace = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.abort_upload(req)
|
||||
|
|
@ -492,7 +492,7 @@ class TestGetUploadStatus:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
|
|
@ -510,7 +510,7 @@ class TestGetUploadStatus:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-expired"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
|
|
@ -527,7 +527,7 @@ class TestGetUploadStatus:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
|
|
@ -539,12 +539,12 @@ class TestGetUploadStatus:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
req.workspace = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.get_upload_status(req)
|
||||
|
|
@ -564,7 +564,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
|
|
@ -587,7 +587,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
|
|
@ -608,7 +608,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
|
|
@ -630,7 +630,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x")
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 0 # Should use default 1MB
|
||||
|
||||
|
|
@ -649,7 +649,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=raw)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 1000
|
||||
|
||||
|
|
@ -666,7 +666,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_size = AsyncMock(return_value=5000)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 512
|
||||
|
||||
|
|
@ -698,7 +698,7 @@ class TestListUploads:
|
|||
]
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.list_uploads(req)
|
||||
|
||||
|
|
@ -713,7 +713,7 @@ class TestListUploads:
|
|||
lib.table_store.list_upload_sessions.return_value = []
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.list_uploads(req)
|
||||
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ def _make_processor(tools=None):
|
|||
agent = MagicMock()
|
||||
agent.tools = tools or {}
|
||||
agent.additional_context = ""
|
||||
processor.agent = agent
|
||||
processor.agents = {"default": agent}
|
||||
processor.aggregator = MagicMock()
|
||||
|
||||
return processor
|
||||
|
|
@ -254,6 +254,7 @@ def _make_flow():
|
|||
return producers[name]
|
||||
|
||||
flow = MagicMock(side_effect=factory)
|
||||
flow.workspace = "default"
|
||||
return flow
|
||||
|
||||
|
||||
|
|
@ -299,7 +300,7 @@ class TestAgentReactDagStructure:
|
|||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
|
|
@ -344,7 +345,6 @@ class TestAgentReactDagStructure:
|
|||
|
||||
request1 = AgentRequest(
|
||||
question="What is 6x7?",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
|
|
@ -433,7 +433,7 @@ class TestAgentPlanDagStructure:
|
|||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
|
|
@ -480,7 +480,6 @@ class TestAgentPlanDagStructure:
|
|||
# Iteration 1: planning
|
||||
request1 = AgentRequest(
|
||||
question="Test?",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
|
|
@ -537,7 +536,7 @@ class TestAgentSupervisorDagStructure:
|
|||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
|
|
@ -563,7 +562,6 @@ class TestAgentSupervisorDagStructure:
|
|||
|
||||
request = AgentRequest(
|
||||
question="Research quantum computing",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=str(uuid.uuid4()),
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
def mock_query_request(self):
|
||||
"""Create a mock query request for testing"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
|
|
@ -69,7 +68,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_single_vector(self, processor):
|
||||
"""Test querying document embeddings with a single vector"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -83,7 +81,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with correct parameters including user/collection
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -101,7 +99,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_longer_vector(self, processor):
|
||||
"""Test querying document embeddings with a longer vector"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=3
|
||||
|
|
@ -115,7 +112,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -133,7 +130,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_with_limit(self, processor):
|
||||
"""Test querying document embeddings respects limit parameter"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
|
|
@ -148,7 +144,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with the specified limit
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -162,13 +158,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_empty_vectors(self, processor):
|
||||
"""Test querying document embeddings with empty vectors list"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -180,7 +175,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_empty_search_results(self, processor):
|
||||
"""Test querying document embeddings with empty search results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -189,7 +183,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
# Mock empty search results
|
||||
processor.vecstore.search.return_value = []
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -203,7 +197,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_unicode_documents(self, processor):
|
||||
"""Test querying document embeddings with Unicode document content"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -217,7 +210,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify Unicode content is preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
|
|
@ -230,7 +223,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_large_documents(self, processor):
|
||||
"""Test querying document embeddings with large document content"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -244,7 +236,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify large content is preserved in ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
|
|
@ -256,7 +248,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_special_characters(self, processor):
|
||||
"""Test querying document embeddings with special characters in documents"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -270,7 +261,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify special characters are preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
|
|
@ -283,13 +274,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||
"""Test querying document embeddings with zero limit"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called (optimization for zero limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -301,13 +291,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_negative_limit(self, processor):
|
||||
"""Test querying document embeddings with negative limit"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=-1
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called (optimization for negative limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -319,7 +308,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_exception_handling(self, processor):
|
||||
"""Test exception handling during query processing"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -330,13 +318,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Milvus connection failed"):
|
||||
await processor.query_document_embeddings(query)
|
||||
await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying document embeddings with different vector dimensions"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
|
||||
limit=5
|
||||
|
|
@ -349,7 +336,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with the vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
|
@ -364,7 +351,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_multiple_results(self, processor):
|
||||
"""Test querying document embeddings with multiple results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
|
|
@ -378,7 +364,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index was accessed correctly (with dimension suffix)
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
chunks = await processor.query_document_embeddings(mock_query_message)
|
||||
chunks = await processor.query_document_embeddings('default', mock_query_message)
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
|
|
@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify limit is passed to query
|
||||
mock_index.query.assert_called_once()
|
||||
|
|
@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify different indexes used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 2
|
||||
|
|
@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no queries were made and empty result returned
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_results.matches = []
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify empty results
|
||||
assert chunks == []
|
||||
|
|
@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify Unicode content is properly handled
|
||||
assert len(chunks) == 2
|
||||
|
|
@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify large content is properly handled
|
||||
assert len(chunks) == 1
|
||||
|
|
@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify all content types are properly handled
|
||||
assert len(chunks) == 5
|
||||
|
|
@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index.query.side_effect = Exception("Query failed")
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_document_embeddings(message)
|
||||
await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_index_access_failure(self, processor):
|
||||
|
|
@ -427,7 +427,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
processor.pinecone.Index.side_effect = Exception("Index access failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index access failed"):
|
||||
await processor.query_document_embeddings(message)
|
||||
await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_vector_accumulation(self, processor):
|
||||
|
|
@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify all queries were made
|
||||
assert mock_index.query.call_count == 3
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters (with dimension suffix)
|
||||
|
|
@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once
|
||||
|
|
@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('limit_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with exact limit (no multiplication)
|
||||
|
|
@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
|
@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('dim_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once with correct collection
|
||||
|
|
@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'utf8_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('utf8_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
|
|
@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
await processor.query_document_embeddings('error_user', mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('zero_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
|
|
@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'large_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('large_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should query with full limit
|
||||
|
|
@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# Act & Assert
|
||||
# This should raise a KeyError when trying to access payload['chunk_id']
|
||||
with pytest.raises(KeyError):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
await processor.query_document_embeddings('payload_user', mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
def mock_query_request(self):
|
||||
"""Create a mock query request for testing"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
|
|
@ -117,7 +116,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
"""Test querying graph embeddings with a single vector"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -131,7 +129,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with correct parameters including user/collection
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -154,7 +152,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_multiple_results(self, processor):
|
||||
"""Test querying graph embeddings returns multiple results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
|
|
@ -168,7 +165,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -186,7 +183,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_with_limit(self, processor):
|
||||
"""Test querying graph embeddings respects limit parameter"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
|
|
@ -201,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with 2*limit for better deduplication
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -215,7 +211,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_preserves_order(self, processor):
|
||||
"""Test that query results preserve order from the vector store"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
|
|
@ -229,7 +224,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify results are in the same order as returned by the store
|
||||
assert len(result) == 3
|
||||
|
|
@ -241,7 +236,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_results_limited(self, processor):
|
||||
"""Test that results are properly limited when store returns more than requested"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=2
|
||||
|
|
@ -255,7 +249,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -269,13 +263,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_empty_vectors(self, processor):
|
||||
"""Test querying graph embeddings with empty vectors list"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -287,7 +280,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_empty_search_results(self, processor):
|
||||
"""Test querying graph embeddings with empty search results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -296,7 +288,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
# Mock empty search results
|
||||
processor.vecstore.search.return_value = []
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -310,7 +302,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
|
||||
"""Test querying graph embeddings with mixed URI and literal results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -325,7 +316,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify all results are properly typed
|
||||
assert len(result) == 4
|
||||
|
|
@ -348,7 +339,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_exception_handling(self, processor):
|
||||
"""Test exception handling during query processing"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -359,7 +349,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Milvus connection failed"):
|
||||
await processor.query_graph_embeddings(query)
|
||||
await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
@ -430,13 +420,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_zero_limit(self, processor):
|
||||
"""Test querying graph embeddings with zero limit"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called (optimization for zero limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -448,7 +437,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_longer_vector(self, processor):
|
||||
"""Test querying graph embeddings with a longer vector"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
limit=5
|
||||
|
|
@ -461,7 +449,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index was accessed correctly (with dimension suffix)
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(mock_query_message)
|
||||
entities = await processor.query_graph_embeddings('default', mock_query_message)
|
||||
|
||||
# Verify query was made once
|
||||
assert mock_index.query.call_count == 1
|
||||
|
|
@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify limit is respected
|
||||
assert len(entities) == 2
|
||||
|
|
@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify correct index used for 2D vector
|
||||
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
|
||||
|
|
@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no queries were made and empty result returned
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_results.matches = []
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify empty results
|
||||
assert entities == []
|
||||
|
|
@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Should get exactly 3 unique entities (respecting limit)
|
||||
assert len(entities) == 3
|
||||
|
|
@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Should only return 2 entities (respecting limit)
|
||||
mock_index.query.assert_called_once()
|
||||
|
|
@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index.query.side_effect = Exception("Query failed")
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_graph_embeddings(message)
|
||||
await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters (with dimension suffix)
|
||||
|
|
@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once
|
||||
|
|
@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('limit_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with limit * 2
|
||||
|
|
@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
|
@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('dim_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once
|
||||
|
|
@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'uri_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('uri_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
|
|
@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_graph_embeddings(mock_message)
|
||||
await processor.query_graph_embeddings('error_user', mock_message)
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('zero_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor
|
|||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMemgraphQueryUserCollectionIsolation:
|
||||
class TestMemgraphQueryWorkspaceCollectionIsolation:
|
||||
"""Test cases for Memgraph query service with user/collection isolation"""
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_spo_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SPO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -22,7 +22,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -32,13 +31,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SPO query for literal includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -48,14 +47,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
value="test_object",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_sp_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SP query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -63,7 +62,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -73,13 +71,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SP query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -88,14 +86,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
expected_literal_query,
|
||||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_so_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_so_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -103,7 +101,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -113,13 +110,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SO query for nodes includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -128,14 +125,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
src="http://example.com/s",
|
||||
uri="http://example.com/o",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_s_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test S-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -143,7 +140,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -153,13 +149,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify S query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -167,14 +163,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
src="http://example.com/s",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_po_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_po_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test PO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -182,7 +178,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -192,13 +187,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify PO query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -207,14 +202,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
value="literal",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_p_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test P-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -222,7 +217,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -232,13 +226,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify P query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -246,14 +240,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_o_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test O-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -261,7 +255,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -271,13 +264,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify O query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -285,14 +278,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
value="test_value",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test wildcard query (all None) includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -300,7 +293,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -310,36 +302,36 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify wildcard query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
# Verify wildcard query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_node_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
|
@ -363,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('default', query)
|
||||
|
||||
# Verify defaults were used
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
|
|
@ -383,7 +375,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -410,7 +401,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
([mock_record2], MagicMock(), MagicMock()) # Node query
|
||||
]
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify results are proper Triple objects
|
||||
assert len(result) == 2
|
||||
|
|
@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor
|
|||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestNeo4jQueryUserCollectionIsolation:
|
||||
class TestNeo4jQueryWorkspaceCollectionIsolation:
|
||||
"""Test cases for Neo4j query service with user/collection isolation"""
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_spo_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SPO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -22,7 +22,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -32,13 +31,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SPO query for literal includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -48,14 +47,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
value="test_object",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_sp_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SP query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -63,7 +62,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -73,13 +71,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SP query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -88,16 +86,16 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_literal_query,
|
||||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
# Verify SP query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -106,14 +104,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_node_query,
|
||||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_so_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_so_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -121,7 +119,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -131,13 +128,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SO query for nodes includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -146,14 +143,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
src="http://example.com/s",
|
||||
uri="http://example.com/o",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_s_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test S-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -161,7 +158,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -171,13 +167,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify S query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -185,14 +181,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
src="http://example.com/s",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_po_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_po_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test PO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -200,7 +196,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -210,13 +205,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify PO query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -225,14 +220,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
value="literal",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_p_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test P-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -240,7 +235,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -250,13 +244,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify P query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -264,14 +258,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_o_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test O-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -279,7 +273,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -289,13 +282,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify O query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -303,14 +296,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
value="test_value",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test wildcard query (all None) includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -318,7 +311,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -328,36 +320,36 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify wildcard query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
# Verify wildcard query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_node_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
|
@ -381,7 +373,7 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('default', query)
|
||||
|
||||
# Verify defaults were used
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
|
|
@ -401,7 +393,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -428,7 +419,7 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
([mock_record2], MagicMock(), MagicMock()) # Node query
|
||||
]
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify results are proper Triple objects
|
||||
assert len(result) == 2
|
||||
|
|
@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic:
|
|||
"""Test parsing of schema configuration"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.schema_builders = {}
|
||||
processor.graphql_schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.schema_builder = MagicMock()
|
||||
processor.schema_builder.clear = MagicMock()
|
||||
processor.schema_builder.add_schema = MagicMock()
|
||||
processor.schema_builder.build = MagicMock(return_value=MagicMock())
|
||||
processor.query_cassandra = MagicMock()
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test config
|
||||
|
|
@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic:
|
|||
}
|
||||
|
||||
# Process config
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
await processor.on_schema_config("default", schema_config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer" in processor.schemas
|
||||
schema = processor.schemas["customer"]
|
||||
assert "customer" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["customer"]
|
||||
assert schema.name == "customer"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
|
@ -147,39 +146,40 @@ class TestRowsGraphQLQueryLogic:
|
|||
status_field = next(f for f in schema.fields if f.name == "status")
|
||||
assert status_field.enum_values == ["active", "inactive"]
|
||||
|
||||
# Verify schema builder was called
|
||||
processor.schema_builder.add_schema.assert_called_once()
|
||||
processor.schema_builder.build.assert_called_once()
|
||||
# Verify per-workspace schema builder was created and graphql schema built
|
||||
assert "default" in processor.schema_builders
|
||||
assert "default" in processor.graphql_schemas
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphql_context_handling(self):
|
||||
"""Test GraphQL execution context setup"""
|
||||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
graphql_schema = AsyncMock()
|
||||
processor.graphql_schemas = {"default": graphql_schema}
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Mock schema execution
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
|
||||
mock_result.errors = None
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
graphql_schema.execute.return_value = mock_result
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
workspace="default",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify schema.execute was called with correct context
|
||||
processor.graphql_schema.execute.assert_called_once()
|
||||
call_args = processor.graphql_schema.execute.call_args
|
||||
graphql_schema.execute.assert_called_once()
|
||||
call_args = graphql_schema.execute.call_args
|
||||
|
||||
# Verify context was passed
|
||||
context = call_args[1]['context_value']
|
||||
assert context["processor"] == processor
|
||||
assert context["user"] == "test_user"
|
||||
assert context["workspace"] == "default"
|
||||
assert context["collection"] == "test_collection"
|
||||
|
||||
# Verify result structure
|
||||
|
|
@ -190,7 +190,8 @@ class TestRowsGraphQLQueryLogic:
|
|||
async def test_error_handling_graphql_errors(self):
|
||||
"""Test GraphQL error handling and conversion"""
|
||||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
graphql_schema = AsyncMock()
|
||||
processor.graphql_schemas = {"default": graphql_schema}
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Create a simple object to simulate GraphQL error
|
||||
|
|
@ -212,13 +213,13 @@ class TestRowsGraphQLQueryLogic:
|
|||
mock_result = MagicMock()
|
||||
mock_result.data = None
|
||||
mock_result.errors = [mock_error]
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
graphql_schema.execute.return_value = mock_result
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
workspace="default",
|
||||
query='{ customers { invalid_field } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -248,7 +249,6 @@ class TestRowsGraphQLQueryLogic:
|
|||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
|
|
@ -259,6 +259,7 @@ class TestRowsGraphQLQueryLogic:
|
|||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "default"
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
|
@ -267,10 +268,10 @@ class TestRowsGraphQLQueryLogic:
|
|||
|
||||
# Verify query was executed
|
||||
processor.execute_graphql_query.assert_called_once_with(
|
||||
workspace="default",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -297,7 +298,6 @@ class TestRowsGraphQLQueryLogic:
|
|||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ invalid_query }',
|
||||
variables={},
|
||||
|
|
@ -357,7 +357,7 @@ class TestUnifiedTableQueries:
|
|||
|
||||
# Query with filter on indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
|
|
@ -374,7 +374,7 @@ class TestUnifiedTableQueries:
|
|||
query = call_args[0][1]
|
||||
params = call_args[0][2]
|
||||
|
||||
assert "SELECT data, source FROM test_user.rows" in query
|
||||
assert "SELECT data, source FROM test_workspace.rows" in query
|
||||
assert "collection = %s" in query
|
||||
assert "schema_name = %s" in query
|
||||
assert "index_name = %s" in query
|
||||
|
|
@ -421,7 +421,7 @@ class TestUnifiedTableQueries:
|
|||
|
||||
# Query with filter on non-indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
# Create query request with all SPO values
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -103,7 +102,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify KnowledgeGraph was created with correct parameters
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -170,7 +169,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -178,7 +176,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=50
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
|
|
@ -207,7 +205,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -215,7 +212,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=25
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
|
|
@ -244,7 +241,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -252,7 +248,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=10
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
|
|
@ -281,7 +277,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -289,7 +284,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=75
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
|
|
@ -319,7 +314,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -327,7 +321,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=1000
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
|
|
@ -425,7 +419,6 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -433,7 +426,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify KnowledgeGraph was created with authentication
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -463,7 +456,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -472,11 +464,11 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
# First query should create TrustGraph
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
assert mock_kg_class.call_count == 1
|
||||
|
||||
# Second query with same table should reuse TrustGraph
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
assert mock_kg_class.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -504,7 +496,6 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
# First query
|
||||
query1 = TriplesQueryRequest(
|
||||
user='user1',
|
||||
collection='collection1',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -512,12 +503,11 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query1)
|
||||
await processor.query_triples('user1', query1)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
user='user2',
|
||||
collection='collection2',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -525,7 +515,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query2)
|
||||
await processor.query_triples('user2', query2)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
|
|
@ -544,7 +534,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -553,7 +542,7 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -582,7 +571,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -590,7 +578,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].o.value == 'object1'
|
||||
|
|
@ -621,7 +609,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# PO query pattern (predicate + object, find subjects)
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -629,7 +616,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=50
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
|
|
@ -662,7 +649,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# OS query pattern (object + subject, find predicates)
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -670,7 +656,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=25
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.get_os.assert_called_once_with(
|
||||
|
|
@ -721,7 +707,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_tg_instance.reset_mock()
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value=s) if s else None,
|
||||
p=Term(type=LITERAL, value=p) if p else None,
|
||||
|
|
@ -729,7 +714,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=10
|
||||
)
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify the correct method was called
|
||||
method = getattr(mock_tg_instance, expected_method)
|
||||
|
|
@ -780,7 +765,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# This is the query pattern that was slow with ALLOW FILTERING
|
||||
query = TriplesQueryRequest(
|
||||
user='large_dataset_user',
|
||||
collection='massive_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
|
||||
|
|
@ -788,7 +772,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=1000
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('large_dataset_user', query)
|
||||
|
||||
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
|
|
|
|||
|
|
@ -123,7 +123,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -131,7 +130,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -164,7 +163,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -172,7 +170,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -209,7 +207,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -217,7 +214,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -254,7 +251,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -262,7 +258,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -299,7 +295,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -307,7 +302,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -344,7 +339,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -352,7 +346,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -389,7 +383,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -397,7 +390,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -434,7 +427,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -442,7 +434,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -474,7 +466,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -484,7 +475,7 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Database connection failed"):
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -122,7 +122,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -130,7 +129,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -164,7 +163,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -172,7 +170,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -210,7 +208,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -218,7 +215,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -256,7 +253,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -264,7 +260,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -302,7 +298,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -310,7 +305,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -348,7 +343,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -356,7 +350,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -394,7 +388,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -402,7 +395,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -440,7 +433,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -448,7 +440,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -478,7 +470,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -488,7 +479,7 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Database connection failed"):
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -122,7 +122,6 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -130,7 +129,7 @@ class TestNeo4jQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -164,7 +163,6 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -172,7 +170,7 @@ class TestNeo4jQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -210,7 +208,6 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -218,7 +215,7 @@ class TestNeo4jQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -248,7 +245,6 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -258,7 +254,7 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Database connection failed"):
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class TestDocumentMetadataTranslator:
|
|||
"title": "Test Document",
|
||||
"comments": "No comments",
|
||||
"metadata": [],
|
||||
"user": "alice",
|
||||
"workspace": "alice",
|
||||
"tags": ["finance", "q4"],
|
||||
"parent-id": "doc-100",
|
||||
"document-type": "page",
|
||||
|
|
@ -40,14 +40,14 @@ class TestDocumentMetadataTranslator:
|
|||
assert obj.time == 1710000000
|
||||
assert obj.kind == "application/pdf"
|
||||
assert obj.title == "Test Document"
|
||||
assert obj.user == "alice"
|
||||
assert obj.workspace == "alice"
|
||||
assert obj.tags == ["finance", "q4"]
|
||||
assert obj.parent_id == "doc-100"
|
||||
assert obj.document_type == "page"
|
||||
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["id"] == "doc-123"
|
||||
assert wire["user"] == "alice"
|
||||
assert wire["workspace"] == "alice"
|
||||
assert wire["parent-id"] == "doc-100"
|
||||
assert wire["document-type"] == "page"
|
||||
|
||||
|
|
@ -80,10 +80,10 @@ class TestDocumentMetadataTranslator:
|
|||
|
||||
def test_falsy_fields_omitted_from_wire(self):
|
||||
"""Empty string fields should be omitted from wire format."""
|
||||
obj = DocumentMetadata(id="", time=0, user="")
|
||||
obj = DocumentMetadata(id="", time=0, workspace="")
|
||||
wire = self.tx.encode(obj)
|
||||
assert "id" not in wire
|
||||
assert "user" not in wire
|
||||
assert "workspace" not in wire
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -101,7 +101,7 @@ class TestProcessingMetadataTranslator:
|
|||
"document-id": "doc-123",
|
||||
"time": 1710000000,
|
||||
"flow": "default",
|
||||
"user": "alice",
|
||||
"workspace": "alice",
|
||||
"collection": "my-collection",
|
||||
"tags": ["tag1"],
|
||||
}
|
||||
|
|
@ -109,20 +109,20 @@ class TestProcessingMetadataTranslator:
|
|||
assert obj.id == "proc-1"
|
||||
assert obj.document_id == "doc-123"
|
||||
assert obj.flow == "default"
|
||||
assert obj.user == "alice"
|
||||
assert obj.workspace == "alice"
|
||||
assert obj.collection == "my-collection"
|
||||
assert obj.tags == ["tag1"]
|
||||
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["id"] == "proc-1"
|
||||
assert wire["document-id"] == "doc-123"
|
||||
assert wire["user"] == "alice"
|
||||
assert wire["workspace"] == "alice"
|
||||
assert wire["collection"] == "my-collection"
|
||||
|
||||
def test_missing_fields_use_defaults(self):
|
||||
obj = self.tx.decode({})
|
||||
assert obj.id is None
|
||||
assert obj.user is None
|
||||
assert obj.workspace is None
|
||||
assert obj.collection is None
|
||||
|
||||
def test_tags_none_omitted(self):
|
||||
|
|
@ -135,10 +135,10 @@ class TestProcessingMetadataTranslator:
|
|||
wire = self.tx.encode(obj)
|
||||
assert wire["tags"] == []
|
||||
|
||||
def test_user_and_collection_preserved(self):
|
||||
def test_workspace_and_collection_preserved(self):
|
||||
"""Core pipeline routing fields must survive round-trip."""
|
||||
data = {"user": "bob", "collection": "research"}
|
||||
data = {"workspace": "bob", "collection": "research"}
|
||||
obj = self.tx.decode(data)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["user"] == "bob"
|
||||
assert wire["workspace"] == "bob"
|
||||
assert wire["collection"] == "research"
|
||||
|
|
|
|||
|
|
@ -61,7 +61,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -69,7 +68,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = [] # Empty vector
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
|
||||
# No upsert should be called
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
|
@ -83,7 +82,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -91,7 +89,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = None # None vector
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -103,7 +101,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -111,7 +108,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = [0.1, 0.2, 0.3]
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -124,7 +121,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -132,7 +128,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = [0.1, 0.2, 0.3]
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -146,7 +142,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "alice"
|
||||
msg.metadata.collection = "docs"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -154,7 +149,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = [0.0] * 384 # 384-dim vector
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("alice", msg)
|
||||
|
||||
call_args = proc.qdrant.upsert.call_args
|
||||
assert "d_alice_docs_384" in call_args[1]["collection_name"]
|
||||
|
|
@ -175,7 +170,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -183,7 +177,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.vector = [0.1, 0.2, 0.3]
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -195,7 +189,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -203,7 +196,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.vector = [0.1, 0.2, 0.3]
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -215,7 +208,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -223,7 +215,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.vector = [] # Empty vector
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -236,7 +228,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -245,7 +236,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.chunk_id = "c1"
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -258,7 +249,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "alice"
|
||||
msg.metadata.collection = "graphs"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -267,7 +257,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.chunk_id = ""
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("alice", msg)
|
||||
|
||||
# Collection should be created with correct dimension
|
||||
proc.qdrant.create_collection.assert_called_once()
|
||||
|
|
@ -290,11 +280,10 @@ class TestCollectionValidation:
|
|||
proc.collection_exists = MagicMock(return_value=False)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "deleted-col"
|
||||
msg.chunks = [MagicMock()]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -306,9 +295,8 @@ class TestCollectionValidation:
|
|||
proc.collection_exists = MagicMock(return_value=False)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "deleted-col"
|
||||
msg.entities = [MagicMock()]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -92,14 +92,13 @@ class TestQuery:
|
|||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.doc_limit == 20 # Default value
|
||||
|
|
@ -112,7 +111,7 @@ class TestQuery:
|
|||
# Initialize Query with custom doc_limit
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="custom_user",
|
||||
workspace="test_workspace",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
doc_limit=50
|
||||
|
|
@ -120,7 +119,6 @@ class TestQuery:
|
|||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.doc_limit == 50
|
||||
|
|
@ -137,7 +135,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -162,7 +160,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -184,7 +182,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -223,7 +221,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=15
|
||||
|
|
@ -240,7 +238,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=15,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -286,7 +283,6 @@ class TestQuery:
|
|||
|
||||
result = await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10
|
||||
)
|
||||
|
|
@ -304,7 +300,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -350,7 +345,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[[0.1, 0.2]],
|
||||
limit=20, # Default doc_limit
|
||||
user="trustgraph", # Default user
|
||||
collection="default" # Default collection
|
||||
)
|
||||
|
||||
|
|
@ -380,7 +374,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=True,
|
||||
doc_limit=5
|
||||
|
|
@ -453,7 +447,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -509,7 +503,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=True
|
||||
)
|
||||
|
|
@ -558,7 +552,6 @@ class TestQuery:
|
|||
|
||||
result = await document_rag.query(
|
||||
query=query_text,
|
||||
user="research_user",
|
||||
collection="ml_knowledge",
|
||||
doc_limit=25
|
||||
)
|
||||
|
|
@ -619,7 +612,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=10
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Unit test for DocumentRAG service parameter passing fix.
|
||||
Tests that user and collection parameters from the message are correctly
|
||||
Tests that the collection parameter from the message is correctly
|
||||
passed to the DocumentRag.query() method.
|
||||
"""
|
||||
|
||||
|
|
@ -16,13 +16,13 @@ class TestDocumentRagService:
|
|||
|
||||
@patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class):
|
||||
async def test_collection_parameter_passed_to_query(self, mock_document_rag_class):
|
||||
"""
|
||||
Test that user and collection from message are passed to DocumentRag.query().
|
||||
|
||||
This is a regression test for the bug where user/collection parameters
|
||||
were ignored, causing wrong collection names like 'd_trustgraph_default_384'
|
||||
instead of 'd_my_user_test_coll_1_384'.
|
||||
Test that collection from message is passed to DocumentRag.query().
|
||||
|
||||
This is a regression test for the bug where the collection parameter
|
||||
was ignored, causing wrong collection names like 'd_trustgraph_default_384'
|
||||
instead of one that reflects the requested collection.
|
||||
"""
|
||||
# Setup processor
|
||||
processor = Processor(
|
||||
|
|
@ -30,17 +30,16 @@ class TestDocumentRagService:
|
|||
id="test-processor",
|
||||
doc_limit=10
|
||||
)
|
||||
|
||||
|
||||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with custom user/collection
|
||||
|
||||
# Setup message with custom collection
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = DocumentRagQuery(
|
||||
query="test query",
|
||||
user="my_user", # Custom user (not default "trustgraph")
|
||||
collection="test_coll_1", # Custom collection (not default "default")
|
||||
doc_limit=5
|
||||
)
|
||||
|
|
@ -64,7 +63,7 @@ class TestDocumentRagService:
|
|||
# Verify: DocumentRag.query was called with correct parameters
|
||||
mock_rag_instance.query.assert_called_once_with(
|
||||
"test query",
|
||||
user="my_user", # Must be from message, not hardcoded default
|
||||
workspace=ANY, # Workspace comes from flow.workspace (mock)
|
||||
collection="test_coll_1", # Must be from message, not hardcoded default
|
||||
doc_limit=5,
|
||||
explain_callback=ANY, # Explainability callback is always passed
|
||||
|
|
@ -103,7 +102,6 @@ class TestDocumentRagService:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = DocumentRagQuery(
|
||||
query="What is a cat?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
doc_limit=10,
|
||||
streaming=False # Non-streaming mode
|
||||
|
|
|
|||
|
|
@ -78,14 +78,12 @@ class TestQuery:
|
|||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.entity_limit == 50 # Default value
|
||||
|
|
@ -101,7 +99,6 @@ class TestQuery:
|
|||
# Initialize Query with custom parameters
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="custom_user",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
entity_limit=100,
|
||||
|
|
@ -112,7 +109,6 @@ class TestQuery:
|
|||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.entity_limit == 100
|
||||
|
|
@ -133,7 +129,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -156,7 +151,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=True
|
||||
)
|
||||
|
|
@ -177,7 +171,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -201,7 +194,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -244,7 +236,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
entity_limit=25
|
||||
|
|
@ -269,7 +260,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -277,7 +267,7 @@ class TestQuery:
|
|||
result = await query.maybe_label("entity1")
|
||||
|
||||
assert result == "Entity One Label"
|
||||
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
|
||||
mock_cache.get.assert_called_once_with("test_collection:entity1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_label_lookup(self):
|
||||
|
|
@ -295,7 +285,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -307,13 +296,12 @@ class TestQuery:
|
|||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
g=""
|
||||
)
|
||||
|
||||
assert result == "Human Readable Label"
|
||||
cache_key = "test_user:test_collection:http://example.com/entity"
|
||||
cache_key = "test_collection:http://example.com/entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -330,7 +318,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -342,13 +329,12 @@ class TestQuery:
|
|||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
g=""
|
||||
)
|
||||
|
||||
assert result == "unlabeled_entity"
|
||||
cache_key = "test_user:test_collection:unlabeled_entity"
|
||||
cache_key = "test_collection:unlabeled_entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -375,7 +361,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
|
|
@ -388,15 +373,15 @@ class TestQuery:
|
|||
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
|
||||
expected_subgraph = {
|
||||
|
|
@ -415,7 +400,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -435,7 +419,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
|
|
@ -455,7 +438,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
|
|
@ -493,7 +475,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
|
|
@ -601,7 +582,6 @@ class TestQuery:
|
|||
try:
|
||||
response = await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15,
|
||||
|
|
|
|||
|
|
@ -120,7 +120,6 @@ class TestGraphRagServiceExplainTriples:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="What is quantum computing?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ class TestGraphRagService:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="What is a cat?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
|
|
@ -123,7 +122,6 @@ class TestGraphRagService:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="What is a cat?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
|
|
@ -190,7 +188,6 @@ class TestGraphRagService:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="Test query",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
streaming=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -286,11 +286,11 @@ class TestNLPQueryProcessor:
|
|||
}
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(config, "v1")
|
||||
await processor.on_schema_config("default", config, "v1")
|
||||
|
||||
# Assert
|
||||
assert "test_schema" in processor.schemas
|
||||
schema = processor.schemas["test_schema"]
|
||||
assert "test_schema" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["test_schema"]
|
||||
assert schema.name == "test_schema"
|
||||
assert schema.description == "Test schema"
|
||||
assert len(schema.fields) == 2
|
||||
|
|
@ -308,10 +308,10 @@ class TestNLPQueryProcessor:
|
|||
}
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(config, "v1")
|
||||
await processor.on_schema_config("default", config, "v1")
|
||||
|
||||
# Assert - bad schema should be ignored
|
||||
assert "bad_schema" not in processor.schemas
|
||||
assert "bad_schema" not in processor.schemas.get("default", {})
|
||||
|
||||
def test_processor_initialization(self, mock_pulsar_client):
|
||||
"""Test processor initialization with correct specifications"""
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ def service(mock_schemas):
|
|||
taskgroup=MagicMock(),
|
||||
id="test-processor"
|
||||
)
|
||||
service.schemas = mock_schemas
|
||||
service.schemas = {"default": dict(mock_schemas)}
|
||||
return service
|
||||
|
||||
|
||||
|
|
@ -109,6 +109,7 @@ def service(mock_schemas):
|
|||
def mock_flow():
|
||||
"""Create mock flow with prompt service"""
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
prompt_request_flow = AsyncMock()
|
||||
flow.return_value.request = prompt_request_flow
|
||||
return flow, prompt_request_flow
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ class TestStructuredQueryProcessor:
|
|||
# Arrange
|
||||
request = StructuredQueryRequest(
|
||||
question="Show me all customers from New York",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
|
||||
|
|
@ -110,7 +109,6 @@ class TestStructuredQueryProcessor:
|
|||
assert isinstance(objects_call_args, RowsQueryRequest)
|
||||
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
|
||||
assert objects_call_args.variables == {"state": "NY"}
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
assert objects_call_args.collection == "default"
|
||||
|
||||
# Verify response
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test document embeddings
|
||||
|
|
@ -80,7 +79,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings for a single chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -89,7 +87,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify insert was called once for the single chunk with its vector
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -99,14 +97,14 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
|
||||
"""Test storing document embeddings for multiple chunks"""
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('test_workspace', mock_message)
|
||||
|
||||
# Verify insert was called once per chunk with user/collection parameters
|
||||
# Verify insert was called once per chunk with workspace/collection parameters
|
||||
expected_calls = [
|
||||
# Chunk 1 - single vector
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_workspace', 'test_collection'),
|
||||
# Chunk 2 - single vector
|
||||
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
|
||||
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_workspace', 'test_collection'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
|
|
@ -122,7 +120,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with empty chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -131,7 +128,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called for empty chunk
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -141,7 +138,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with None chunk_id"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -150,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Note: Implementation passes through None chunk_ids (only skips empty string "")
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -162,7 +158,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with mix of valid and empty chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_chunk = ChunkEmbeddings(
|
||||
|
|
@ -179,7 +174,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [valid_chunk, empty_chunk, another_valid]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify valid chunks were inserted, empty string chunk was skipped
|
||||
expected_calls = [
|
||||
|
|
@ -200,11 +195,10 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with empty chunks list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.chunks = []
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -214,7 +208,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings for chunk with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -223,7 +216,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called (no vectors to insert)
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -233,7 +226,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Each chunk has a single vector of different dimensions
|
||||
|
|
@ -251,7 +243,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk1, chunk2, chunk3]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify all vectors were inserted regardless of dimension with user/collection parameters
|
||||
expected_calls = [
|
||||
|
|
@ -273,7 +265,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with Unicode content in chunk_id"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -282,7 +273,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify Unicode chunk_id was stored correctly with user/collection parameters
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -294,7 +285,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with long chunk_id"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a long chunk_id
|
||||
|
|
@ -305,7 +295,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify long chunk_id was inserted with user/collection parameters
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -317,7 +307,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with whitespace-only chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -326,7 +315,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify whitespace content was inserted (not filtered out) with user/collection parameters
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -343,25 +332,24 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
('test@domain.com', 'test-collection.v1'),
|
||||
]
|
||||
|
||||
for user, collection in test_cases:
|
||||
for workspace, collection in test_cases:
|
||||
processor.vecstore.reset_mock() # Reset mock for each test case
|
||||
|
||||
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = user
|
||||
message.metadata.collection = collection
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="Test content",
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify insert was called with the correct user/collection
|
||||
|
||||
await processor.store_document_embeddings(workspace, message)
|
||||
|
||||
# Verify insert was called with the correct workspace/collection
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], "Test content", user, collection
|
||||
[0.1, 0.2, 0.3], "Test content", workspace, collection
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -370,7 +358,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
# Store embeddings for user1/collection1
|
||||
message1 = MagicMock()
|
||||
message1.metadata = MagicMock()
|
||||
message1.metadata.user = 'user1'
|
||||
message1.metadata.collection = 'collection1'
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk_id="User1 content",
|
||||
|
|
@ -381,7 +368,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
# Store embeddings for user2/collection2
|
||||
message2 = MagicMock()
|
||||
message2.metadata = MagicMock()
|
||||
message2.metadata.user = 'user2'
|
||||
message2.metadata.collection = 'collection2'
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk_id="User2 content",
|
||||
|
|
@ -389,8 +375,8 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message2.chunks = [chunk2]
|
||||
|
||||
await processor.store_document_embeddings(message1)
|
||||
await processor.store_document_embeddings(message2)
|
||||
await processor.store_document_embeddings('user1', message1)
|
||||
await processor.store_document_embeddings('user2', message2)
|
||||
|
||||
# Verify both calls were made with correct parameters
|
||||
expected_calls = [
|
||||
|
|
@ -411,18 +397,17 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with special characters in user/collection names"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'user@domain.com' # Email-like user
|
||||
message.metadata.collection = 'test-collection.v1' # Collection with special chars
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="Special chars test",
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify the exact user/collection strings are passed (sanitization happens in DocVectors)
|
||||
|
||||
await processor.store_document_embeddings('user@domain.com', message)
|
||||
|
||||
# Verify the exact workspace/collection strings are passed (sanitization happens in DocVectors)
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1'
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test document embeddings
|
||||
|
|
@ -120,7 +119,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings for a single chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -135,7 +133,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index name and operations (with dimension suffix)
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -185,7 +183,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test that writing to non-existent index creates it lazily"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -200,7 +197,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created with correct dimension
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -217,7 +214,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with empty chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -229,7 +225,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for empty chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -239,7 +235,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with None chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -251,7 +246,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for None chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -261,7 +256,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with chunk that decodes to empty string"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -273,7 +267,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for empty decoded chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -283,7 +277,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Each chunk has a single vector of different dimensions
|
||||
|
|
@ -325,14 +318,13 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with empty chunks list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.chunks = []
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no operations were performed
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -343,7 +335,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings for chunk with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -355,7 +346,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called (no vectors to insert)
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -365,7 +356,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test that lazy creation happens when index doesn't exist"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -380,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
|
@ -390,7 +380,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test that lazy creation works correctly"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -405,7 +394,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created and used
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
|
@ -416,7 +405,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with Unicode content"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -430,7 +418,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify Unicode content was properly decoded and stored
|
||||
call_args = mock_index.upsert.call_args
|
||||
|
|
@ -442,7 +430,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with large document chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a large document chunk
|
||||
|
|
@ -458,7 +445,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify large content was stored
|
||||
call_args = mock_index.upsert.call_args
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with chunks and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
|
|
@ -94,7 +93,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked (with dimension suffix)
|
||||
|
|
@ -138,7 +137,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with multiple chunks
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
|
|
@ -152,7 +150,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per chunk)
|
||||
|
|
@ -198,7 +196,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with multiple chunks, each having a single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
|
|
@ -216,7 +213,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('vector_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per chunk)
|
||||
|
|
@ -255,7 +252,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with empty chunk_id
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_chunk_empty = MagicMock()
|
||||
|
|
@ -265,7 +261,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk_empty]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty chunk_ids
|
||||
|
|
@ -298,7 +294,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
|
|
@ -308,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('new_user', mock_message)
|
||||
|
||||
# Assert - collection should be lazily created
|
||||
expected_collection = 'd_new_user_new_collection_5' # 5 dimensions
|
||||
|
|
@ -350,7 +345,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'error_user'
|
||||
mock_message.metadata.collection = 'error_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
|
|
@ -361,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act & Assert - should propagate the creation error
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('error_user', mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
|
|
@ -388,7 +382,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create first mock message
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'cache_user'
|
||||
mock_message1.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
|
|
@ -398,7 +391,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message1.chunks = [mock_chunk1]
|
||||
|
||||
# First call
|
||||
await processor.store_document_embeddings(mock_message1)
|
||||
await processor.store_document_embeddings('cache_user', mock_message1)
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
|
@ -406,7 +399,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create second mock message with same dimensions
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'cache_user'
|
||||
mock_message2.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
|
|
@ -416,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message2.chunks = [mock_chunk2]
|
||||
|
||||
# Act - Second call with same collection
|
||||
await processor.store_document_embeddings(mock_message2)
|
||||
await processor.store_document_embeddings('cache_user', mock_message2)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
|
||||
|
|
@ -452,7 +444,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with chunks of different dimensions
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'dim_user'
|
||||
mock_message.metadata.collection = 'dim_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
|
|
@ -466,7 +457,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('dim_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should check existence of DIFFERENT collections for each dimension
|
||||
|
|
@ -526,7 +517,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with URI-style chunk_id
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'uri_user'
|
||||
mock_message.metadata.collection = 'uri_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
|
|
@ -536,7 +526,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('uri_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify the chunk_id was stored correctly
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test entities with embeddings
|
||||
|
|
@ -80,7 +79,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for a single entity"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -89,7 +87,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify insert was called once with the full vector
|
||||
processor.vecstore.insert.assert_called_once()
|
||||
|
|
@ -102,14 +100,14 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
|
||||
"""Test storing graph embeddings for multiple entities"""
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('test_workspace', mock_message)
|
||||
|
||||
# Verify insert was called once per entity with user/collection parameters
|
||||
# Verify insert was called once per entity with workspace/collection parameters
|
||||
expected_calls = [
|
||||
# Entity 1 - single vector
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_workspace', 'test_collection'),
|
||||
# Entity 2 - single vector
|
||||
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
|
||||
([0.7, 0.8, 0.9], 'literal entity', 'test_workspace', 'test_collection'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
|
|
@ -125,7 +123,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with empty entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -134,7 +131,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called for empty entity
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -144,7 +141,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with None entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -153,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called for None entity
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -163,7 +159,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with mix of valid and invalid entities"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_entity = EntityEmbeddings(
|
||||
|
|
@ -183,7 +178,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [valid_entity, empty_entity, none_entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify only valid entity was inserted with user/collection/chunk_id parameters
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -196,11 +191,10 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with empty entities list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.entities = []
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -210,7 +204,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for entity with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -219,7 +212,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called (no vectors to insert)
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -229,7 +222,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Each entity has a single vector of different dimensions
|
||||
|
|
@ -247,7 +239,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity1, entity2, entity3]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify all vectors were inserted regardless of dimension
|
||||
expected_calls = [
|
||||
|
|
@ -267,7 +259,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for both URI and literal entities"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
uri_entity = EntityEmbeddings(
|
||||
|
|
@ -280,7 +271,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [uri_entity, literal_entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify both entities were inserted
|
||||
expected_calls = [
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test entity embeddings (each entity has a single vector)
|
||||
|
|
@ -124,7 +123,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for a single entity"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -139,7 +137,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1']):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index name and operations (with dimension suffix)
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -189,7 +187,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test that writing to non-existent index creates it lazily"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -204,7 +201,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created with correct dimension
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -221,7 +218,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with empty entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -233,7 +229,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for empty entity
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -243,7 +239,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with None entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -255,7 +250,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for None entity
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -265,7 +260,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Each entity has a single vector of different dimensions
|
||||
|
|
@ -288,7 +282,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify different indexes were used for different dimensions
|
||||
index_calls = processor.pinecone.Index.call_args_list
|
||||
|
|
@ -307,14 +301,13 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with empty entities list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.entities = []
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no operations were performed
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -325,7 +318,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for entity with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -337,7 +329,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called (no vectors to insert)
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -347,7 +339,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test that lazy creation happens when index doesn't exist"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -362,7 +353,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
|
@ -372,7 +363,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test that lazy creation works correctly"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -387,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created and used
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with entities and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
|
|
@ -75,7 +74,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.entities = [mock_entity]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked (with dimension suffix)
|
||||
|
|
@ -118,7 +117,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with multiple entities
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
|
|
@ -134,7 +132,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per entity)
|
||||
|
|
@ -179,7 +177,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with three entities
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
|
|
@ -200,7 +197,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('vector_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per entity)
|
||||
|
|
@ -238,7 +235,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with empty entity value
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity_empty = MagicMock()
|
||||
|
|
@ -253,7 +249,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty entities
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Tests for Memgraph user/collection isolation in storage service
|
||||
Tests for Memgraph workspace/collection isolation in storage service.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -8,47 +8,45 @@ from unittest.mock import MagicMock, patch
|
|||
from trustgraph.storage.triples.memgraph.write import Processor
|
||||
|
||||
|
||||
class TestMemgraphUserCollectionIsolation:
|
||||
"""Test cases for Memgraph storage service with user/collection isolation"""
|
||||
class TestMemgraphWorkspaceCollectionIsolation:
|
||||
"""Test cases for Memgraph storage service with workspace/collection isolation"""
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db):
|
||||
"""Test that storage creates both legacy and user/collection indexes"""
|
||||
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test that storage creates both legacy and workspace/collection indexes"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Verify all indexes were attempted (4 legacy + 4 user/collection = 8 total)
|
||||
|
||||
# 4 legacy + 4 workspace/collection = 8 total
|
||||
assert mock_session.run.call_count == 8
|
||||
|
||||
# Check some specific index creation calls
|
||||
|
||||
expected_calls = [
|
||||
"CREATE INDEX ON :Node",
|
||||
"CREATE INDEX ON :Node(uri)",
|
||||
"CREATE INDEX ON :Literal",
|
||||
"CREATE INDEX ON :Literal(value)",
|
||||
"CREATE INDEX ON :Node(user)",
|
||||
"CREATE INDEX ON :Node(workspace)",
|
||||
"CREATE INDEX ON :Node(collection)",
|
||||
"CREATE INDEX ON :Literal(user)",
|
||||
"CREATE INDEX ON :Literal(workspace)",
|
||||
"CREATE INDEX ON :Literal(collection)"
|
||||
]
|
||||
|
||||
|
||||
for expected_call in expected_calls:
|
||||
mock_session.run.assert_any_call(expected_call)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_user_collection(self, mock_graph_db):
|
||||
"""Test that store_triples includes user/collection in all operations"""
|
||||
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test that store_triples includes workspace/collection in all operations"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
|
|
@ -58,45 +56,39 @@ class TestMemgraphUserCollectionIsolation:
|
|||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Create mock triple with URI object
|
||||
from trustgraph.schema import IRI
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "http://example.com/object"
|
||||
triple.o.is_uri = True
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = IRI
|
||||
triple.o.iri = "http://example.com/object"
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples("test_workspace", mock_message)
|
||||
|
||||
# Verify user/collection parameters were passed to all operations
|
||||
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
|
||||
# create_node (subject), create_node (object), relate_node = 3 calls
|
||||
assert mock_driver.execute_query.call_count == 3
|
||||
|
||||
# Check that user and collection were included in all calls
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
assert 'user' in call_kwargs
|
||||
assert 'collection' in call_kwargs
|
||||
assert call_kwargs['user'] == "test_user"
|
||||
assert call_kwargs['collection'] == "test_collection"
|
||||
for c in mock_driver.execute_query.call_args_list:
|
||||
kwargs = c.kwargs
|
||||
assert kwargs['workspace'] == "test_workspace"
|
||||
assert kwargs['collection'] == "test_collection"
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_default_user_collection(self, mock_graph_db):
|
||||
"""Test that defaults are used when user/collection not provided in metadata"""
|
||||
async def test_store_triples_with_default_collection(self, mock_graph_db):
|
||||
"""Test that default collection is used when not provided in metadata"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
|
|
@ -106,157 +98,151 @@ class TestMemgraphUserCollectionIsolation:
|
|||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Create mock triple
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = "literal_value"
|
||||
triple.o.is_uri = False
|
||||
|
||||
# Create mock message without user/collection metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = None
|
||||
mock_message.metadata.collection = None
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples("default", mock_message)
|
||||
|
||||
# Verify defaults were used
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
assert call_kwargs['user'] == "default"
|
||||
assert call_kwargs['collection'] == "default"
|
||||
for c in mock_driver.execute_query.call_args_list:
|
||||
kwargs = c.kwargs
|
||||
assert kwargs['workspace'] == "default"
|
||||
assert kwargs['collection'] == "default"
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_create_node_includes_user_collection(self, mock_graph_db):
|
||||
"""Test that create_node includes user/collection properties"""
|
||||
def test_create_node_includes_workspace_collection(self, mock_graph_db):
|
||||
"""Test that create_node includes workspace/collection properties"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
processor.create_node("http://example.com/node", "test_user", "test_collection")
|
||||
|
||||
|
||||
processor.create_node("http://example.com/node", "test_workspace", "test_collection")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri="http://example.com/node",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="memgraph"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_create_literal_includes_user_collection(self, mock_graph_db):
|
||||
"""Test that create_literal includes user/collection properties"""
|
||||
def test_create_literal_includes_workspace_collection(self, mock_graph_db):
|
||||
"""Test that create_literal includes workspace/collection properties"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
processor.create_literal("test_value", "test_user", "test_collection")
|
||||
|
||||
|
||||
processor.create_literal("test_value", "test_workspace", "test_collection")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value="test_value",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="memgraph"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_relate_node_includes_user_collection(self, mock_graph_db):
|
||||
"""Test that relate_node includes user/collection properties"""
|
||||
def test_relate_node_includes_workspace_collection(self, mock_graph_db):
|
||||
"""Test that relate_node includes workspace/collection properties"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
processor.relate_node(
|
||||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"http://example.com/predicate",
|
||||
"http://example.com/object",
|
||||
"test_user",
|
||||
"test_workspace",
|
||||
"test_collection"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="http://example.com/object",
|
||||
uri="http://example.com/predicate",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="memgraph"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_relate_literal_includes_user_collection(self, mock_graph_db):
|
||||
"""Test that relate_literal includes user/collection properties"""
|
||||
def test_relate_literal_includes_workspace_collection(self, mock_graph_db):
|
||||
"""Test that relate_literal includes workspace/collection properties"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
processor.relate_literal(
|
||||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"literal_value",
|
||||
"test_user",
|
||||
"test_workspace",
|
||||
"test_collection"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="literal_value",
|
||||
uri="http://example.com/predicate",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="memgraph"
|
||||
)
|
||||
|
|
@ -264,20 +250,15 @@ class TestMemgraphUserCollectionIsolation:
|
|||
def test_add_args_includes_memgraph_parameters(self):
|
||||
"""Test that add_args properly configures Memgraph-specific parameters"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
|
||||
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added with Memgraph defaults
|
||||
|
||||
args = parser.parse_args([])
|
||||
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'bolt://memgraph:7687'
|
||||
assert hasattr(args, 'username')
|
||||
|
|
@ -288,19 +269,18 @@ class TestMemgraphUserCollectionIsolation:
|
|||
assert args.database == 'memgraph'
|
||||
|
||||
|
||||
class TestMemgraphUserCollectionRegression:
|
||||
"""Regression tests to ensure user/collection isolation prevents data leakage"""
|
||||
class TestMemgraphWorkspaceCollectionRegression:
|
||||
"""Regression tests to ensure workspace/collection isolation prevents data leakage"""
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_no_cross_user_data_access(self, mock_graph_db):
|
||||
"""Regression test: Ensure users cannot access each other's data"""
|
||||
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
|
||||
"""Regression test: Ensure workspaces cannot access each other's data"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
|
|
@ -310,60 +290,55 @@ class TestMemgraphUserCollectionRegression:
|
|||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Store data for user1
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "user1_data"
|
||||
triple.o.is_uri = False
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = "ws1_data"
|
||||
|
||||
message_user1 = MagicMock()
|
||||
message_user1.triples = [triple]
|
||||
message_user1.metadata.user = "user1"
|
||||
message_user1.metadata.collection = "collection1"
|
||||
message_ws1 = MagicMock()
|
||||
message_ws1.triples = [triple]
|
||||
message_ws1.metadata.collection = "collection1"
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples("workspace1", message_ws1)
|
||||
|
||||
# Verify that all storage operations included user1/collection1 parameters
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
if 'user' in call_kwargs:
|
||||
assert call_kwargs['user'] == "user1"
|
||||
assert call_kwargs['collection'] == "collection1"
|
||||
for c in mock_driver.execute_query.call_args_list:
|
||||
kwargs = c.kwargs
|
||||
if 'workspace' in kwargs:
|
||||
assert kwargs['workspace'] == "workspace1"
|
||||
assert kwargs['collection'] == "collection1"
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_same_uri_different_users(self, mock_graph_db):
|
||||
"""Regression test: Same URI can exist for different users without conflict"""
|
||||
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
|
||||
"""Regression test: Same URI can exist in different workspaces without conflict"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Same URI for different users should create separate nodes
|
||||
processor.create_node("http://example.com/same-uri", "user1", "collection1")
|
||||
processor.create_node("http://example.com/same-uri", "user2", "collection2")
|
||||
|
||||
# Verify both calls were made with different user/collection parameters
|
||||
calls = mock_driver.execute_query.call_args_list[-2:] # Get last 2 calls
|
||||
|
||||
call1_kwargs = calls[0].kwargs if hasattr(calls[0], 'kwargs') else calls[0][1]
|
||||
call2_kwargs = calls[1].kwargs if hasattr(calls[1], 'kwargs') else calls[1][1]
|
||||
|
||||
assert call1_kwargs['user'] == "user1" and call1_kwargs['collection'] == "collection1"
|
||||
assert call2_kwargs['user'] == "user2" and call2_kwargs['collection'] == "collection2"
|
||||
|
||||
# Both should have the same URI but different user/collection
|
||||
assert call1_kwargs['uri'] == call2_kwargs['uri'] == "http://example.com/same-uri"
|
||||
|
||||
processor.create_node("http://example.com/same-uri", "workspace1", "collection1")
|
||||
processor.create_node("http://example.com/same-uri", "workspace2", "collection2")
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list[-2:]
|
||||
|
||||
k1 = calls[0].kwargs
|
||||
k2 = calls[1].kwargs
|
||||
|
||||
assert k1['workspace'] == "workspace1" and k1['collection'] == "collection1"
|
||||
assert k2['workspace'] == "workspace2" and k2['collection'] == "collection2"
|
||||
|
||||
assert k1['uri'] == k2['uri'] == "http://example.com/same-uri"
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Tests for Neo4j user/collection isolation in triples storage and query
|
||||
Tests for Neo4j workspace/collection isolation in triples storage and query.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -11,468 +11,406 @@ from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
|
|||
from trustgraph.schema import TriplesQueryRequest
|
||||
|
||||
|
||||
class TestNeo4jUserCollectionIsolation:
|
||||
"""Test cases for Neo4j user/collection isolation functionality"""
|
||||
class TestNeo4jWorkspaceCollectionIsolation:
|
||||
"""Test cases for Neo4j workspace/collection isolation functionality"""
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db):
|
||||
"""Test that storage service creates compound indexes for user/collection"""
|
||||
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test that storage service creates compound indexes for workspace/collection"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Verify both legacy and new compound indexes are created
|
||||
|
||||
expected_indexes = [
|
||||
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
|
||||
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
|
||||
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
|
||||
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
|
||||
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
|
||||
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
|
||||
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
|
||||
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
|
||||
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
|
||||
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
|
||||
]
|
||||
|
||||
# Check that all expected indexes were created
|
||||
|
||||
for expected_query in expected_indexes:
|
||||
mock_session.run.assert_any_call(expected_query)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_user_collection(self, mock_graph_db):
|
||||
"""Test that triples are stored with user/collection properties"""
|
||||
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test that triples are stored with workspace/collection properties"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create test message with user/collection metadata
|
||||
metadata = Metadata(
|
||||
id="test-id",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
metadata = Metadata(id="test-id", collection="test_collection")
|
||||
|
||||
triple = Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal_value")
|
||||
)
|
||||
|
||||
message = Triples(
|
||||
metadata=metadata,
|
||||
triples=[triple]
|
||||
)
|
||||
|
||||
# Mock execute_query to return summaries
|
||||
|
||||
message = Triples(metadata=metadata, triples=[triple])
|
||||
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify nodes and relationships were created with user/collection properties
|
||||
await processor.store_triples("test_workspace", message)
|
||||
|
||||
expected_calls = [
|
||||
call(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri="http://example.com/subject",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
),
|
||||
call(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value="literal_value",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
),
|
||||
call(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="literal_value",
|
||||
uri="http://example.com/predicate",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
for expected_call in expected_calls:
|
||||
mock_driver.execute_query.assert_any_call(*expected_call.args, **expected_call.kwargs)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_default_user_collection(self, mock_graph_db):
|
||||
"""Test that default user/collection are used when not provided"""
|
||||
async def test_store_triples_with_default_collection(self, mock_graph_db):
|
||||
"""Test that default collection is used when not provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create test message without user/collection
|
||||
|
||||
metadata = Metadata(id="test-id")
|
||||
|
||||
|
||||
triple = Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=IRI, iri="http://example.com/object")
|
||||
)
|
||||
|
||||
message = Triples(
|
||||
metadata=metadata,
|
||||
triples=[triple]
|
||||
)
|
||||
|
||||
# Mock execute_query
|
||||
|
||||
message = Triples(metadata=metadata, triples=[triple])
|
||||
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify defaults were used
|
||||
await processor.store_triples("default", message)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri="http://example.com/subject",
|
||||
user="default",
|
||||
workspace="default",
|
||||
collection="default",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_triples_filters_by_user_collection(self, mock_graph_db):
|
||||
"""Test that query service filters results by user/collection"""
|
||||
async def test_query_triples_filters_by_workspace_collection(self, mock_graph_db):
|
||||
"""Test that query service filters results by workspace/collection"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
||||
|
||||
processor = QueryProcessor(taskgroup=MagicMock())
|
||||
|
||||
# Create test query
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None
|
||||
)
|
||||
|
||||
# Mock query results
|
||||
|
||||
mock_records = [
|
||||
MagicMock(data=lambda: {"dest": "http://example.com/object1"}),
|
||||
MagicMock(data=lambda: {"dest": "literal_value"})
|
||||
]
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock())
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify queries include user/collection filters
|
||||
|
||||
await processor.query_triples("test_workspace", query)
|
||||
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.value as dest"
|
||||
)
|
||||
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN dest.uri as dest"
|
||||
)
|
||||
|
||||
# Check that queries were executed with user/collection parameters
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
assert any(
|
||||
expected_literal_query in str(call) and
|
||||
"user='test_user'" in str(call) and
|
||||
"collection='test_collection'" in str(call)
|
||||
for call in calls
|
||||
expected_literal_query in str(c) and
|
||||
"workspace='test_workspace'" in str(c) and
|
||||
"collection='test_collection'" in str(c)
|
||||
for c in calls
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_triples_with_default_user_collection(self, mock_graph_db):
|
||||
"""Test that query service uses defaults when user/collection not provided"""
|
||||
async def test_query_triples_with_default_collection(self, mock_graph_db):
|
||||
"""Test that query service uses default collection when not provided"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
||||
|
||||
processor = QueryProcessor(taskgroup=MagicMock())
|
||||
|
||||
# Create test query without user/collection
|
||||
query = TriplesQueryRequest(
|
||||
s=None,
|
||||
p=None,
|
||||
o=None
|
||||
)
|
||||
|
||||
# Mock empty results
|
||||
|
||||
query = TriplesQueryRequest(s=None, p=None, o=None)
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify defaults were used in queries
|
||||
|
||||
await processor.query_triples("default", query)
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
assert any(
|
||||
"user='default'" in str(call) and "collection='default'" in str(call)
|
||||
for call in calls
|
||||
"workspace='default'" in str(c) and "collection='default'" in str(c)
|
||||
for c in calls
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_isolation_between_users(self, mock_graph_db):
|
||||
"""Test that data from different users is properly isolated"""
|
||||
async def test_data_isolation_between_workspaces(self, mock_graph_db):
|
||||
"""Test that data from different workspaces is properly isolated"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create messages for different users
|
||||
message_user1 = Triples(
|
||||
metadata=Metadata(user="user1", collection="coll1"),
|
||||
|
||||
message_ws1 = Triples(
|
||||
metadata=Metadata(collection="coll1"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/user1/subject"),
|
||||
s=Term(type=IRI, iri="http://example.com/ws1/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="user1_data")
|
||||
o=Term(type=LITERAL, value="ws1_data")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
message_user2 = Triples(
|
||||
metadata=Metadata(user="user2", collection="coll2"),
|
||||
|
||||
message_ws2 = Triples(
|
||||
metadata=Metadata(collection="coll2"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/user2/subject"),
|
||||
s=Term(type=IRI, iri="http://example.com/ws2/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="user2_data")
|
||||
o=Term(type=LITERAL, value="ws2_data")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Mock execute_query
|
||||
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
# Store data for both users
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Verify user1 data was stored with user1/coll1
|
||||
await processor.store_triples("workspace1", message_ws1)
|
||||
await processor.store_triples("workspace2", message_ws2)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
value="user1_data",
|
||||
user="user1",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value="ws1_data",
|
||||
workspace="workspace1",
|
||||
collection="coll1",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
# Verify user2 data was stored with user2/coll2
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
value="user2_data",
|
||||
user="user2",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value="ws2_data",
|
||||
workspace="workspace2",
|
||||
collection="coll2",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_query_respects_user_collection(self, mock_graph_db):
|
||||
"""Test that wildcard queries still filter by user/collection"""
|
||||
async def test_wildcard_query_respects_workspace_collection(self, mock_graph_db):
|
||||
"""Test that wildcard queries still filter by workspace/collection"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
||||
|
||||
processor = QueryProcessor(taskgroup=MagicMock())
|
||||
|
||||
# Create wildcard query (all nulls) with user/collection
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
o=None
|
||||
s=None, p=None, o=None,
|
||||
)
|
||||
|
||||
# Mock results
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify wildcard queries include user/collection filters
|
||||
|
||||
await processor.query_triples("test_workspace", query)
|
||||
|
||||
wildcard_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
|
||||
)
|
||||
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
assert any(
|
||||
wildcard_query in str(call) and
|
||||
"user='test_user'" in str(call) and
|
||||
"collection='test_collection'" in str(call)
|
||||
for call in calls
|
||||
wildcard_query in str(c) and
|
||||
"workspace='test_workspace'" in str(c) and
|
||||
"collection='test_collection'" in str(c)
|
||||
for c in calls
|
||||
)
|
||||
|
||||
def test_add_args_includes_neo4j_parameters(self):
|
||||
"""Test that add_args includes Neo4j-specific parameters"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
|
||||
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
|
||||
StorageProcessor.add_args(parser)
|
||||
|
||||
|
||||
args = parser.parse_args([])
|
||||
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert hasattr(args, 'username')
|
||||
assert hasattr(args, 'password')
|
||||
assert hasattr(args, 'database')
|
||||
|
||||
# Check defaults
|
||||
|
||||
assert args.graph_host == 'bolt://neo4j:7687'
|
||||
assert args.username == 'neo4j'
|
||||
assert args.password == 'password'
|
||||
assert args.database == 'neo4j'
|
||||
|
||||
|
||||
class TestNeo4jUserCollectionRegression:
|
||||
"""Regression tests to ensure user/collection isolation prevents data leaks"""
|
||||
|
||||
class TestNeo4jWorkspaceCollectionRegression:
|
||||
"""Regression tests to ensure workspace/collection isolation prevents data leaks"""
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_no_cross_user_data_access(self, mock_graph_db):
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
|
||||
"""
|
||||
Regression test: Ensure user1 cannot access user2's data
|
||||
|
||||
This test guards against the bug where all users shared the same
|
||||
Neo4j graph space, causing data contamination between users.
|
||||
Regression test: Ensure workspace1 cannot access workspace2's data.
|
||||
|
||||
Guards against a bug where all data shared the same Neo4j graph
|
||||
space, causing data contamination between workspaces.
|
||||
"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
||||
|
||||
processor = QueryProcessor(taskgroup=MagicMock())
|
||||
|
||||
# User1 queries for all triples
|
||||
query_user1 = TriplesQueryRequest(
|
||||
user="user1",
|
||||
|
||||
query_ws1 = TriplesQueryRequest(
|
||||
collection="collection1",
|
||||
s=None, p=None, o=None
|
||||
)
|
||||
|
||||
# Mock that the database has data but none matching user1/collection1
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
result = await processor.query_triples(query_user1)
|
||||
|
||||
# Verify empty results (user1 cannot see other users' data)
|
||||
|
||||
result = await processor.query_triples("workspace1", query_ws1)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
# Verify the query included user/collection filters
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
for call in calls:
|
||||
query_str = str(call)
|
||||
for c in calls:
|
||||
query_str = str(c)
|
||||
if "MATCH" in query_str:
|
||||
assert "user: $user" in query_str or "user='user1'" in query_str
|
||||
assert "workspace: $workspace" in query_str or "workspace='workspace1'" in query_str
|
||||
assert "collection: $collection" in query_str or "collection='collection1'" in query_str
|
||||
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_same_uri_different_users(self, mock_graph_db):
|
||||
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
|
||||
"""
|
||||
Regression test: Same URI in different user contexts should create separate nodes
|
||||
|
||||
This ensures that http://example.com/entity for user1 is completely separate
|
||||
from http://example.com/entity for user2.
|
||||
Regression test: Same URI in different workspace contexts should create separate nodes.
|
||||
|
||||
Ensures http://example.com/entity in workspace1 is completely
|
||||
separate from the same URI in workspace2.
|
||||
"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Same URI for different users
|
||||
|
||||
shared_uri = "http://example.com/shared_entity"
|
||||
|
||||
message_user1 = Triples(
|
||||
metadata=Metadata(user="user1", collection="coll1"),
|
||||
|
||||
message_ws1 = Triples(
|
||||
metadata=Metadata(collection="coll1"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri=shared_uri),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="user1_value")
|
||||
o=Term(type=LITERAL, value="ws1_value")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
message_user2 = Triples(
|
||||
metadata=Metadata(user="user2", collection="coll2"),
|
||||
|
||||
message_ws2 = Triples(
|
||||
metadata=Metadata(collection="coll2"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri=shared_uri),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="user2_value")
|
||||
o=Term(type=LITERAL, value="ws2_value")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Mock execute_query
|
||||
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Verify two separate nodes were created with same URI but different user/collection
|
||||
user1_node_call = call(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
await processor.store_triples("workspace1", message_ws1)
|
||||
await processor.store_triples("workspace2", message_ws2)
|
||||
|
||||
ws1_node_call = call(
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri=shared_uri,
|
||||
user="user1",
|
||||
workspace="workspace1",
|
||||
collection="coll1",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
user2_node_call = call(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
|
||||
ws2_node_call = call(
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri=shared_uri,
|
||||
user="user2",
|
||||
workspace="workspace2",
|
||||
collection="coll2",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_has_calls([user1_node_call, user2_node_call], any_order=True)
|
||||
|
||||
mock_driver.execute_query.assert_has_calls([ws1_node_call, ws2_node_call], any_order=True)
|
||||
|
|
@ -1,3 +1,12 @@
|
|||
|
||||
def _flow_mock(workspace):
|
||||
"""Build a mock flow object that is callable and exposes .workspace."""
|
||||
from unittest.mock import MagicMock
|
||||
f = MagicMock()
|
||||
f.workspace = workspace
|
||||
return f
|
||||
|
||||
|
||||
"""
|
||||
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
|
||||
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
|
||||
|
|
@ -92,13 +101,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
collection_name = processor.get_collection_name(
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
schema_name="customer_data",
|
||||
dimension=384
|
||||
)
|
||||
|
||||
assert collection_name == "rows_test_user_test_collection_customer_data_384"
|
||||
assert collection_name == "rows_test_workspace_test_collection_customer_data_384"
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
|
||||
|
|
@ -185,11 +194,10 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('test_workspace', 'test_collection')] = {}
|
||||
|
||||
# Create embeddings message
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
|
|
@ -210,14 +218,14 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3'
|
||||
assert upsert_call_args[1]['collection_name'] == 'rows_test_workspace_test_collection_customers_3'
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
|
|
@ -243,10 +251,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('test_workspace', 'test_collection')] = {}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
|
|
@ -267,7 +274,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
|
||||
|
||||
# Should be called once for the single embedding
|
||||
assert mock_qdrant_instance.upsert.call_count == 1
|
||||
|
|
@ -287,10 +294,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('test_workspace', 'test_collection')] = {}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
|
|
@ -311,7 +317,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
|
||||
|
||||
# Should not call upsert for empty vectors
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
|
|
@ -334,7 +340,6 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# No collections registered
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'unknown_user'
|
||||
metadata.collection = 'unknown_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
|
|
@ -354,7 +359,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
|
||||
|
||||
# Should not call upsert for unknown collection
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
|
|
@ -368,11 +373,11 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock collections list
|
||||
mock_coll1 = MagicMock()
|
||||
mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
|
||||
mock_coll1.name = 'rows_test_workspace_test_collection_schema1_384'
|
||||
mock_coll2 = MagicMock()
|
||||
mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
|
||||
mock_coll2.name = 'rows_test_workspace_test_collection_schema2_384'
|
||||
mock_coll3 = MagicMock()
|
||||
mock_coll3.name = 'rows_other_user_other_collection_schema_384'
|
||||
mock_coll3.name = 'rows_other_workspace_other_collection_schema_384'
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
|
||||
|
|
@ -386,15 +391,15 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add('rows_test_user_test_collection_schema1_384')
|
||||
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
|
||||
|
||||
await processor.delete_collection('test_user', 'test_collection')
|
||||
await processor.delete_collection('test_workspace', 'test_collection')
|
||||
|
||||
# Should delete only the matching collections
|
||||
assert mock_qdrant_instance.delete_collection.call_count == 2
|
||||
|
||||
# Verify the cached collection was removed
|
||||
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
|
||||
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_delete_collection_schema(self, mock_qdrant_client):
|
||||
|
|
@ -404,9 +409,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance = MagicMock()
|
||||
|
||||
mock_coll1 = MagicMock()
|
||||
mock_coll1.name = 'rows_test_user_test_collection_customers_384'
|
||||
mock_coll1.name = 'rows_test_workspace_test_collection_customers_384'
|
||||
mock_coll2 = MagicMock()
|
||||
mock_coll2.name = 'rows_test_user_test_collection_orders_384'
|
||||
mock_coll2.name = 'rows_test_workspace_test_collection_orders_384'
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll1, mock_coll2]
|
||||
|
|
@ -422,13 +427,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
await processor.delete_collection_schema(
|
||||
'test_user', 'test_collection', 'customers'
|
||||
'test_workspace', 'test_collection', 'customers'
|
||||
)
|
||||
|
||||
# Should only delete the customers schema collection
|
||||
mock_qdrant_instance.delete_collection.assert_called_once()
|
||||
call_args = mock_qdrant_instance.delete_collection.call_args[0]
|
||||
assert call_args[0] == 'rows_test_user_test_collection_customers_384'
|
||||
assert call_args[0] == 'rows_test_workspace_test_collection_customers_384'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
|
|||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
|
||||
|
||||
class _MockFlowDefault:
|
||||
"""Mock Flow with default workspace for testing."""
|
||||
workspace = "default"
|
||||
name = "default"
|
||||
id = "test-processor"
|
||||
|
||||
|
||||
mock_flow_default = _MockFlowDefault()
|
||||
|
||||
class TestRowsCassandraStorageLogic:
|
||||
"""Test business logic for unified table implementation"""
|
||||
|
||||
|
|
@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic:
|
|||
}
|
||||
|
||||
# Process configuration
|
||||
await processor.on_schema_config(config, version=1)
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer_records" in processor.schemas
|
||||
schema = processor.schemas["customer_records"]
|
||||
assert "customer_records" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["customer_records"]
|
||||
assert schema.name == "customer_records"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
|
@ -165,16 +176,18 @@ class TestRowsCassandraStorageLogic:
|
|||
"""Test that row processing stores data as map<text, text>"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="string", size=100)
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="string", size=100)
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.tables_initialized = {"default"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
|
@ -191,7 +204,6 @@ class TestRowsCassandraStorageLogic:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
schema_name="test_schema",
|
||||
|
|
@ -205,7 +217,7 @@ class TestRowsCassandraStorageLogic:
|
|||
msg.value.return_value = test_obj
|
||||
|
||||
# Process object
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify insert was executed
|
||||
mock_async_execute.assert_called()
|
||||
|
|
@ -214,7 +226,7 @@ class TestRowsCassandraStorageLogic:
|
|||
values = insert_call[0][2]
|
||||
|
||||
# Verify using unified rows table
|
||||
assert "INSERT INTO test_user.rows" in insert_cql
|
||||
assert "INSERT INTO default.rows" in insert_cql
|
||||
|
||||
# Values should be: (collection, schema_name, index_name, index_value, data, source)
|
||||
assert values[0] == "test_collection" # collection
|
||||
|
|
@ -230,16 +242,18 @@ class TestRowsCassandraStorageLogic:
|
|||
"""Test that row is written once per indexed field"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"multi_index_schema": RowSchema(
|
||||
name="multi_index_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"multi_index_schema": RowSchema(
|
||||
name="multi_index_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.tables_initialized = {"default"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
|
@ -255,7 +269,6 @@ class TestRowsCassandraStorageLogic:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
schema_name="multi_index_schema",
|
||||
|
|
@ -267,7 +280,7 @@ class TestRowsCassandraStorageLogic:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 inserts (one per indexed field: id, category, status)
|
||||
assert mock_async_execute.call_count == 3
|
||||
|
|
@ -290,15 +303,17 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
"""Test processing of batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"batch_schema": RowSchema(
|
||||
name="batch_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"batch_schema": RowSchema(
|
||||
name="batch_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.tables_initialized = {"default"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
|
@ -315,7 +330,6 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_collection",
|
||||
),
|
||||
schema_name="batch_schema",
|
||||
|
|
@ -331,7 +345,7 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 inserts (one per row, one index per row since only primary key)
|
||||
assert mock_async_execute.call_count == 3
|
||||
|
|
@ -349,12 +363,14 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
"""Test processing of empty batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"empty_schema": RowSchema(
|
||||
name="empty_schema",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
"default": {
|
||||
"empty_schema": RowSchema(
|
||||
name="empty_schema",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.tables_initialized = {"default"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
|
@ -369,7 +385,6 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
empty_batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="empty-001",
|
||||
user="test_user",
|
||||
collection="empty_collection",
|
||||
),
|
||||
schema_name="empty_schema",
|
||||
|
|
@ -381,7 +396,7 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = empty_batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify no insert calls for empty batch
|
||||
processor.session.execute.assert_not_called()
|
||||
|
|
@ -446,19 +461,21 @@ class TestPartitionRegistration:
|
|||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
|
||||
|
||||
# Should have 2 inserts (one per index: id, category)
|
||||
assert processor.session.execute.call_count == 2
|
||||
|
|
@ -473,7 +490,7 @@ class TestPartitionRegistration:
|
|||
processor.session = MagicMock()
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
|
||||
|
||||
# Should not execute any CQL since already registered
|
||||
processor.session.execute.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -102,11 +102,10 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify KnowledgeGraph was called with auth parameters
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -129,11 +128,10 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user2'
|
||||
mock_message.metadata.collection = 'collection2'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user2', mock_message)
|
||||
|
||||
# Verify KnowledgeGraph was called without auth parameters
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -154,16 +152,15 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
# First call should create TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
assert mock_kg_class.call_count == 1
|
||||
|
||||
# Second call with same table should reuse TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
assert mock_kg_class.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -205,11 +202,10 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
|
|
@ -234,11 +230,10 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message with empty triples
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
|
|
@ -255,12 +250,11 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
|
@ -361,21 +355,19 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# First message with table1
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'user1'
|
||||
mock_message1.metadata.collection = 'collection1'
|
||||
mock_message1.triples = []
|
||||
|
||||
await processor.store_triples(mock_message1)
|
||||
await processor.store_triples('user1', mock_message1)
|
||||
assert processor.table == 'user1'
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'user2'
|
||||
mock_message2.metadata.collection = 'collection2'
|
||||
mock_message2.triples = []
|
||||
|
||||
await processor.store_triples(mock_message2)
|
||||
await processor.store_triples('user2', mock_message2)
|
||||
assert processor.table == 'user2'
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
|
|
@ -407,11 +399,10 @@ class TestCassandraStorageProcessor:
|
|||
triple.g = None
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('test_workspace', mock_message)
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
|
|
@ -440,12 +431,11 @@ class TestCassandraStorageProcessor:
|
|||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('new_user', mock_message)
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
|
|
@ -468,11 +458,10 @@ class TestCassandraPerformanceOptimizations:
|
|||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify KnowledgeGraph instance uses legacy mode
|
||||
assert mock_tg_instance is not None
|
||||
|
|
@ -489,11 +478,10 @@ class TestCassandraPerformanceOptimizations:
|
|||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify KnowledgeGraph instance is in optimized mode
|
||||
assert mock_tg_instance is not None
|
||||
|
|
@ -523,11 +511,10 @@ class TestCassandraPerformanceOptimizations:
|
|||
triple.g = None
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a test triple
|
||||
|
|
@ -89,13 +88,13 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_node(test_uri, 'test_user', 'test_collection')
|
||||
processor.create_node(test_uri, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
params={
|
||||
"uri": test_uri,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -109,13 +108,13 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_literal(test_value, 'test_user', 'test_collection')
|
||||
processor.create_literal(test_value, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
params={
|
||||
"value": test_value,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -132,17 +131,17 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection')
|
||||
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
params={
|
||||
"src": src_uri,
|
||||
"dest": dest_uri,
|
||||
"uri": pred_uri,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -159,17 +158,17 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection')
|
||||
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
params={
|
||||
"src": src_uri,
|
||||
"dest": literal_value,
|
||||
"uri": pred_uri,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -179,7 +178,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Test storing triple with URI object"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple = Triple(
|
||||
|
|
@ -200,21 +198,21 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
# Create object node
|
||||
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
# Create relationship
|
||||
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
]
|
||||
|
||||
assert processor.io.query.call_count == 3
|
||||
|
|
@ -237,21 +235,21 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
await processor.store_triples('test_workspace', mock_message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
# Create literal object
|
||||
(("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",),
|
||||
{"params": {"value": "literal object", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",),
|
||||
{"params": {"value": "literal object", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
# Create relationship
|
||||
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
]
|
||||
|
||||
assert processor.io.query.call_count == 3
|
||||
|
|
@ -265,7 +263,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Test storing multiple triples"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
|
|
@ -291,7 +288,7 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
|
@ -313,7 +310,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Test storing empty triples list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.triples = []
|
||||
|
||||
|
|
@ -323,7 +319,7 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify no queries were made
|
||||
processor.io.query.assert_not_called()
|
||||
|
|
@ -333,7 +329,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Test storing triples with mixed URI and literal objects"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
|
|
@ -359,7 +354,7 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
|
@ -450,13 +445,13 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_node(test_uri, 'test_user', 'test_collection')
|
||||
processor.create_node(test_uri, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
params={
|
||||
"uri": test_uri,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -470,13 +465,13 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_literal(test_value, 'test_user', 'test_collection')
|
||||
processor.create_literal(test_value, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
params={
|
||||
"value": test_value,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue