mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 17:39:39 +02:00
Merge remote-tracking branch 'origin/master' into ts-port
This commit is contained in:
commit
f8252ecd54
1038 changed files with 253274 additions and 8466 deletions
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
# import asyncio
|
||||
# import tracemalloc
|
||||
# import warnings
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Uncomment the lines below to enable asyncio debug mode and tracemalloc
|
||||
|
|
@ -33,19 +34,14 @@ def mock_loki_handler(session_mocker=None):
|
|||
# Create a mock LokiHandler that does nothing
|
||||
original_loki_handler = logging_loki.LokiHandler
|
||||
|
||||
class MockLokiHandler:
|
||||
class MockLokiHandler(logging.Handler):
|
||||
"""Mock LokiHandler that doesn't make network calls."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
super().__init__()
|
||||
|
||||
def emit(self, record):
|
||||
pass
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
return
|
||||
|
||||
# Replace the real LokiHandler with our mock
|
||||
logging_loki.LokiHandler = MockLokiHandler
|
||||
|
|
|
|||
|
|
@ -72,7 +72,6 @@ def sample_message_data():
|
|||
},
|
||||
"DocumentRagQuery": {
|
||||
"query": "What is artificial intelligence?",
|
||||
"user": "test_user",
|
||||
"collection": "test_collection",
|
||||
"doc_limit": 10
|
||||
},
|
||||
|
|
@ -87,7 +86,7 @@ def sample_message_data():
|
|||
"history": []
|
||||
},
|
||||
"AgentResponse": {
|
||||
"chunk_type": "answer",
|
||||
"message_type": "answer",
|
||||
"content": "Machine learning is a subset of AI.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
|
|
@ -95,7 +94,6 @@ def sample_message_data():
|
|||
},
|
||||
"Metadata": {
|
||||
"id": "test-doc-123",
|
||||
"user": "test_user",
|
||||
"collection": "test_collection"
|
||||
},
|
||||
"Term": {
|
||||
|
|
@ -130,9 +128,8 @@ def invalid_message_data():
|
|||
{}, # Missing required fields
|
||||
],
|
||||
"DocumentRagQuery": [
|
||||
{"query": None, "user": "test", "collection": "test", "doc_limit": 10}, # Invalid query
|
||||
{"query": "test", "user": None, "collection": "test", "doc_limit": 10}, # Invalid user
|
||||
{"query": "test", "user": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
|
||||
{"query": None, "collection": "test", "doc_limit": 10}, # Invalid query
|
||||
{"query": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
|
||||
{"query": "test"}, # Missing required fields
|
||||
],
|
||||
"Term": [
|
||||
|
|
|
|||
|
|
@ -18,24 +18,18 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
|
||||
def test_request_schema_fields(self):
|
||||
"""Test that DocumentEmbeddingsRequest has expected fields"""
|
||||
# Create a request
|
||||
request = DocumentEmbeddingsRequest(
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify all expected fields exist
|
||||
assert hasattr(request, 'vector')
|
||||
assert hasattr(request, 'limit')
|
||||
assert hasattr(request, 'user')
|
||||
assert hasattr(request, 'collection')
|
||||
|
||||
# Verify field values
|
||||
assert request.vector == [0.1, 0.2, 0.3]
|
||||
assert request.limit == 10
|
||||
assert request.user == "test_user"
|
||||
assert request.collection == "test_collection"
|
||||
|
||||
def test_request_translator_decode(self):
|
||||
|
|
@ -45,7 +39,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
data = {
|
||||
"vector": [0.1, 0.2, 0.3, 0.4],
|
||||
"limit": 5,
|
||||
"user": "custom_user",
|
||||
"collection": "custom_collection"
|
||||
}
|
||||
|
||||
|
|
@ -54,7 +47,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vector == [0.1, 0.2, 0.3, 0.4]
|
||||
assert result.limit == 5
|
||||
assert result.user == "custom_user"
|
||||
assert result.collection == "custom_collection"
|
||||
|
||||
def test_request_translator_decode_with_defaults(self):
|
||||
|
|
@ -63,7 +55,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
|
||||
data = {
|
||||
"vector": [0.1, 0.2]
|
||||
# No limit, user, or collection provided
|
||||
# No limit or collection provided
|
||||
}
|
||||
|
||||
result = translator.decode(data)
|
||||
|
|
@ -71,7 +63,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vector == [0.1, 0.2]
|
||||
assert result.limit == 10 # Default
|
||||
assert result.user == "trustgraph" # Default
|
||||
assert result.collection == "default" # Default
|
||||
|
||||
def test_request_translator_encode(self):
|
||||
|
|
@ -81,7 +72,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
request = DocumentEmbeddingsRequest(
|
||||
vector=[0.5, 0.6],
|
||||
limit=20,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -90,7 +80,6 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert isinstance(result, dict)
|
||||
assert result["vector"] == [0.5, 0.6]
|
||||
assert result["limit"] == 20
|
||||
assert result["user"] == "test_user"
|
||||
assert result["collection"] == "test_collection"
|
||||
|
||||
|
||||
|
|
@ -219,7 +208,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
request_data = {
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
"limit": 5,
|
||||
"user": "test_user",
|
||||
"collection": "test_collection"
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -132,7 +132,6 @@ class TestDocumentRagMessageContracts:
|
|||
# Test required fields
|
||||
query = DocumentRagQuery(**query_data)
|
||||
assert hasattr(query, 'query')
|
||||
assert hasattr(query, 'user')
|
||||
assert hasattr(query, 'collection')
|
||||
assert hasattr(query, 'doc_limit')
|
||||
|
||||
|
|
@ -154,12 +153,10 @@ class TestDocumentRagMessageContracts:
|
|||
# Test valid query
|
||||
valid_query = DocumentRagQuery(
|
||||
query="What is AI?",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5
|
||||
)
|
||||
assert valid_query.query == "What is AI?"
|
||||
assert valid_query.user == "test_user"
|
||||
assert valid_query.collection == "test_collection"
|
||||
assert valid_query.doc_limit == 5
|
||||
|
||||
|
|
@ -212,7 +209,7 @@ class TestAgentMessageContracts:
|
|||
|
||||
# Test required fields
|
||||
response = AgentResponse(**response_data)
|
||||
assert hasattr(response, 'chunk_type')
|
||||
assert hasattr(response, 'message_type')
|
||||
assert hasattr(response, 'content')
|
||||
assert hasattr(response, 'end_of_message')
|
||||
assert hasattr(response, 'end_of_dialog')
|
||||
|
|
@ -400,7 +397,6 @@ class TestMetadataMessageContracts:
|
|||
|
||||
metadata = Metadata(**metadata_data)
|
||||
assert metadata.id == "test-doc-123"
|
||||
assert metadata.user == "test_user"
|
||||
assert metadata.collection == "test_collection"
|
||||
|
||||
def test_error_schema_contract(self):
|
||||
|
|
@ -491,7 +487,7 @@ class TestSchemaEvolutionContracts:
|
|||
required_fields = {
|
||||
"TextCompletionRequest": ["system", "prompt"],
|
||||
"TextCompletionResponse": ["error", "response", "model"],
|
||||
"DocumentRagQuery": ["query", "user", "collection"],
|
||||
"DocumentRagQuery": ["query", "collection"],
|
||||
"DocumentRagResponse": ["error", "response"],
|
||||
"AgentRequest": ["question", "history"],
|
||||
"AgentResponse": ["error"],
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ class TestOrchestrationFieldContracts:
|
|||
def test_agent_request_orchestration_fields_roundtrip(self):
|
||||
req = AgentRequest(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
correlation_id="corr-123",
|
||||
parent_session_id="parent-sess",
|
||||
|
|
@ -42,7 +41,6 @@ class TestOrchestrationFieldContracts:
|
|||
def test_agent_request_orchestration_fields_default_empty(self):
|
||||
req = AgentRequest(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
)
|
||||
|
||||
assert req.correlation_id == ""
|
||||
|
|
@ -82,7 +80,6 @@ class TestSubagentCompletionStepContract:
|
|||
)
|
||||
req = AgentRequest(
|
||||
question="goal",
|
||||
user="testuser",
|
||||
correlation_id="corr-123",
|
||||
history=[step],
|
||||
)
|
||||
|
|
@ -126,7 +123,6 @@ class TestSynthesisStepContract:
|
|||
|
||||
req = AgentRequest(
|
||||
question="Original question",
|
||||
user="testuser",
|
||||
pattern="supervisor",
|
||||
correlation_id="",
|
||||
session_id="parent-sess",
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ class TestRowsCassandraContracts:
|
|||
# Create test object with all required fields
|
||||
test_metadata = Metadata(
|
||||
id="test-doc-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -47,7 +46,6 @@ class TestRowsCassandraContracts:
|
|||
|
||||
# Verify metadata structure
|
||||
assert hasattr(test_object.metadata, 'id')
|
||||
assert hasattr(test_object.metadata, 'user')
|
||||
assert hasattr(test_object.metadata, 'collection')
|
||||
|
||||
# Verify types
|
||||
|
|
@ -150,7 +148,6 @@ class TestRowsCassandraContracts:
|
|||
original = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="serial-001",
|
||||
user="test_user",
|
||||
collection="test_coll",
|
||||
),
|
||||
schema_name="test_schema",
|
||||
|
|
@ -168,7 +165,6 @@ class TestRowsCassandraContracts:
|
|||
|
||||
# Verify round-trip
|
||||
assert decoded.metadata.id == original.metadata.id
|
||||
assert decoded.metadata.user == original.metadata.user
|
||||
assert decoded.metadata.collection == original.metadata.collection
|
||||
assert decoded.schema_name == original.schema_name
|
||||
assert decoded.values == original.values
|
||||
|
|
@ -228,8 +224,7 @@ class TestRowsCassandraContracts:
|
|||
# Create test object
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="meta-001",
|
||||
user="user123", # -> keyspace
|
||||
id="meta-001", # -> keyspace
|
||||
collection="coll456", # -> partition key
|
||||
),
|
||||
schema_name="table789", # -> table name
|
||||
|
|
@ -242,7 +237,6 @@ class TestRowsCassandraContracts:
|
|||
# - metadata.user -> Cassandra keyspace
|
||||
# - schema_name -> Cassandra table
|
||||
# - metadata.collection -> Part of primary key
|
||||
assert test_obj.metadata.user # Required for keyspace
|
||||
assert test_obj.schema_name # Required for table
|
||||
assert test_obj.metadata.collection # Required for partition key
|
||||
|
||||
|
|
@ -256,7 +250,6 @@ class TestRowsCassandraContractsBatch:
|
|||
# Create test object with multiple values in batch
|
||||
test_metadata = Metadata(
|
||||
id="batch-doc-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -302,7 +295,6 @@ class TestRowsCassandraContractsBatch:
|
|||
"""Test empty batch ExtractedObject contract"""
|
||||
test_metadata = Metadata(
|
||||
id="empty-batch-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -324,7 +316,6 @@ class TestRowsCassandraContractsBatch:
|
|||
"""Test single-item batch (backward compatibility) contract"""
|
||||
test_metadata = Metadata(
|
||||
id="single-batch-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -353,7 +344,6 @@ class TestRowsCassandraContractsBatch:
|
|||
original = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-serial-001",
|
||||
user="test_user",
|
||||
collection="test_coll",
|
||||
),
|
||||
schema_name="test_schema",
|
||||
|
|
@ -375,7 +365,6 @@ class TestRowsCassandraContractsBatch:
|
|||
|
||||
# Verify round-trip for batch
|
||||
assert decoded.metadata.id == original.metadata.id
|
||||
assert decoded.metadata.user == original.metadata.user
|
||||
assert decoded.metadata.collection == original.metadata.collection
|
||||
assert decoded.schema_name == original.schema_name
|
||||
assert len(decoded.values) == len(original.values)
|
||||
|
|
@ -425,8 +414,7 @@ class TestRowsCassandraContractsBatch:
|
|||
# 3. Be stored in the same keyspace (user)
|
||||
|
||||
test_metadata = Metadata(
|
||||
id="partition-test-001",
|
||||
user="consistent_user", # Same keyspace
|
||||
id="partition-test-001", # Same keyspace
|
||||
collection="consistent_collection", # Same partition
|
||||
)
|
||||
|
||||
|
|
@ -443,7 +431,6 @@ class TestRowsCassandraContractsBatch:
|
|||
)
|
||||
|
||||
# Verify consistency contract
|
||||
assert batch_object.metadata.user # Must have user for keyspace
|
||||
assert batch_object.metadata.collection # Must have collection for partition key
|
||||
|
||||
# Verify unique primary keys in batch
|
||||
|
|
|
|||
|
|
@ -21,29 +21,25 @@ class TestRowsGraphQLQueryContracts:
|
|||
"""Test RowsQueryRequest schema structure and required fields"""
|
||||
# Create test request with all required fields
|
||||
test_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ customers { id name email } }',
|
||||
variables={"status": "active", "limit": "10"},
|
||||
operation_name="GetCustomers"
|
||||
)
|
||||
|
||||
|
||||
# Verify all required fields are present
|
||||
assert hasattr(test_request, 'user')
|
||||
assert hasattr(test_request, 'collection')
|
||||
assert hasattr(test_request, 'collection')
|
||||
assert hasattr(test_request, 'query')
|
||||
assert hasattr(test_request, 'variables')
|
||||
assert hasattr(test_request, 'operation_name')
|
||||
|
||||
|
||||
# Verify field types
|
||||
assert isinstance(test_request.user, str)
|
||||
assert isinstance(test_request.collection, str)
|
||||
assert isinstance(test_request.query, str)
|
||||
assert isinstance(test_request.variables, dict)
|
||||
assert isinstance(test_request.operation_name, str)
|
||||
|
||||
|
||||
# Verify content
|
||||
assert test_request.user == "test_user"
|
||||
assert test_request.collection == "test_collection"
|
||||
assert "customers" in test_request.query
|
||||
assert test_request.variables["status"] == "active"
|
||||
|
|
@ -53,15 +49,13 @@ class TestRowsGraphQLQueryContracts:
|
|||
"""Test RowsQueryRequest with minimal required fields"""
|
||||
# Create request with only essential fields
|
||||
minimal_request = RowsQueryRequest(
|
||||
user="user",
|
||||
collection="collection",
|
||||
query='{ test }',
|
||||
variables={},
|
||||
operation_name=""
|
||||
)
|
||||
|
||||
|
||||
# Verify minimal request is valid
|
||||
assert minimal_request.user == "user"
|
||||
assert minimal_request.collection == "collection"
|
||||
assert minimal_request.query == '{ test }'
|
||||
assert minimal_request.variables == {}
|
||||
|
|
@ -187,22 +181,20 @@ class TestRowsGraphQLQueryContracts:
|
|||
"""Test that request/response can be serialized/deserialized correctly"""
|
||||
# Create original request
|
||||
original_request = RowsQueryRequest(
|
||||
user="serialization_test",
|
||||
collection="test_data",
|
||||
query='{ orders(limit: 5) { id total customer { name } } }',
|
||||
variables={"limit": "5", "status": "active"},
|
||||
operation_name="GetRecentOrders"
|
||||
)
|
||||
|
||||
|
||||
# Test request serialization using Pulsar schema
|
||||
request_schema = AvroSchema(RowsQueryRequest)
|
||||
|
||||
|
||||
# Encode and decode request
|
||||
encoded_request = request_schema.encode(original_request)
|
||||
decoded_request = request_schema.decode(encoded_request)
|
||||
|
||||
|
||||
# Verify request round-trip
|
||||
assert decoded_request.user == original_request.user
|
||||
assert decoded_request.collection == original_request.collection
|
||||
assert decoded_request.query == original_request.query
|
||||
assert decoded_request.variables == original_request.variables
|
||||
|
|
@ -245,7 +237,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
"""Test supported GraphQL query formats"""
|
||||
# Test basic query
|
||||
basic_query = RowsQueryRequest(
|
||||
user="test", collection="test", query='{ customers { id } }',
|
||||
collection="test", query='{ customers { id } }',
|
||||
variables={}, operation_name=""
|
||||
)
|
||||
assert "customers" in basic_query.query
|
||||
|
|
@ -254,7 +246,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
# Test query with variables
|
||||
parameterized_query = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
collection="test",
|
||||
query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }',
|
||||
variables={"status": "active", "limit": "10"},
|
||||
operation_name="GetCustomers"
|
||||
|
|
@ -266,7 +258,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
# Test complex nested query
|
||||
nested_query = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
collection="test",
|
||||
query='''
|
||||
{
|
||||
customers(limit: 10) {
|
||||
|
|
@ -297,7 +289,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
# This test verifies the current contract, though ideally we'd support all JSON types
|
||||
|
||||
variables_test = RowsQueryRequest(
|
||||
user="test", collection="test", query='{ test }',
|
||||
collection="test", query='{ test }',
|
||||
variables={
|
||||
"string_var": "test_value",
|
||||
"numeric_var": "123", # Numbers as strings due to Map(String()) limitation
|
||||
|
|
@ -318,22 +310,18 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
def test_cassandra_context_fields_contract(self):
|
||||
"""Test that request contains necessary fields for Cassandra operations"""
|
||||
# Verify request has fields needed for Cassandra keyspace/table targeting
|
||||
# Verify request has fields needed for partition key targeting
|
||||
request = RowsQueryRequest(
|
||||
user="keyspace_name", # Maps to Cassandra keyspace
|
||||
collection="partition_collection", # Used in partition key
|
||||
query='{ objects { id } }',
|
||||
variables={}, operation_name=""
|
||||
)
|
||||
|
||||
# These fields are required for proper Cassandra operations
|
||||
assert request.user # Required for keyspace identification
|
||||
assert request.collection # Required for partition key
|
||||
|
||||
|
||||
# Required for partition key
|
||||
assert request.collection
|
||||
|
||||
# Verify field naming follows TrustGraph patterns (matching other query services)
|
||||
# This matches TriplesQueryRequest, DocumentEmbeddingsRequest patterns
|
||||
assert hasattr(request, 'user') # Same as TriplesQueryRequest.user
|
||||
assert hasattr(request, 'collection') # Same as TriplesQueryRequest.collection
|
||||
assert hasattr(request, 'collection')
|
||||
|
||||
def test_graphql_extensions_contract(self):
|
||||
"""Test GraphQL extensions field format and usage"""
|
||||
|
|
@ -405,7 +393,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
# Request to execute specific operation
|
||||
multi_op_request = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
collection="test",
|
||||
query=multi_op_query,
|
||||
variables={},
|
||||
operation_name="GetCustomers"
|
||||
|
|
@ -418,7 +406,7 @@ class TestRowsGraphQLQueryContracts:
|
|||
|
||||
# Test single operation (operation_name optional)
|
||||
single_op_request = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
collection="test",
|
||||
query='{ customers { id } }',
|
||||
variables={}, operation_name=""
|
||||
)
|
||||
|
|
|
|||
74
tests/contract/test_schema_field_contracts.py
Normal file
74
tests/contract/test_schema_field_contracts.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
"""
|
||||
Contract tests for schema dataclass field sets.
|
||||
|
||||
These pin the *field names* of small, widely-constructed schema dataclasses
|
||||
so that any rename, removal, or accidental addition fails CI loudly instead
|
||||
of waiting for a runtime TypeError on the next websocket message.
|
||||
|
||||
Background: in v2.2 the `Metadata` dataclass dropped a `metadata: list[Triple]`
|
||||
field but several call sites kept passing `Metadata(metadata=...)`. The bug
|
||||
was only discovered when a websocket import dispatcher received its first
|
||||
real message in production. A trivial structural assertion of the kind
|
||||
below would have caught it at unit-test time.
|
||||
|
||||
Add to this file whenever a schema rename burns you. The cost of a frozen
|
||||
field set is a one-line update when you intentionally evolve the schema; the
|
||||
benefit is that every call site is forced to come along for the ride.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import (
|
||||
Metadata,
|
||||
EntityContext,
|
||||
EntityEmbeddings,
|
||||
ChunkEmbeddings,
|
||||
)
|
||||
|
||||
|
||||
def _field_names(dc):
|
||||
return {f.name for f in dataclasses.fields(dc)}
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestSchemaFieldContracts:
|
||||
"""Pin the field set of dataclasses that get constructed all over the
|
||||
codebase. If you intentionally change one of these, update the
|
||||
expected set in the same commit — that diff will surface every call
|
||||
site that needs to come along."""
|
||||
|
||||
def test_metadata_fields(self):
|
||||
# NOTE: there is no `metadata` field. A previous regression
|
||||
# constructed Metadata(metadata=...) and crashed at runtime.
|
||||
# `user` was also dropped in the workspace refactor — workspace
|
||||
# now flows via flow.workspace, not via message payload.
|
||||
assert _field_names(Metadata) == {
|
||||
"id",
|
||||
"root",
|
||||
"collection",
|
||||
}
|
||||
|
||||
def test_entity_embeddings_fields(self):
|
||||
# NOTE: the embedding field is `vector` (singular, list[float]).
|
||||
# There is no `vectors` field. Several call sites historically
|
||||
# passed `vectors=` and crashed at runtime.
|
||||
assert _field_names(EntityEmbeddings) == {
|
||||
"entity",
|
||||
"vector",
|
||||
"chunk_id",
|
||||
}
|
||||
|
||||
def test_chunk_embeddings_fields(self):
|
||||
# Same `vector` (singular) convention as EntityEmbeddings.
|
||||
assert _field_names(ChunkEmbeddings) == {
|
||||
"chunk_id",
|
||||
"vector",
|
||||
}
|
||||
|
||||
def test_entity_context_fields(self):
|
||||
assert _field_names(EntityContext) == {
|
||||
"entity",
|
||||
"context",
|
||||
"chunk_id",
|
||||
}
|
||||
|
|
@ -93,7 +93,6 @@ class TestStructuredDataSchemaContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="structured-data-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -118,7 +117,6 @@ class TestStructuredDataSchemaContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="extracted-obj-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -143,7 +141,6 @@ class TestStructuredDataSchemaContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="extracted-batch-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -177,7 +174,6 @@ class TestStructuredDataSchemaContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="extracted-empty-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -277,7 +273,6 @@ class TestStructuredEmbeddingsContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="struct-embed-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -308,7 +303,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_structured_data_submission_serialization(self):
|
||||
"""Test StructuredDataSubmission serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
metadata = Metadata(id="test", collection="col")
|
||||
submission_data = {
|
||||
"metadata": metadata,
|
||||
"format": "json",
|
||||
|
|
@ -323,7 +318,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_extracted_object_serialization(self):
|
||||
"""Test ExtractedObject serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
metadata = Metadata(id="test", collection="col")
|
||||
object_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
|
|
@ -373,7 +368,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_extracted_object_batch_serialization(self):
|
||||
"""Test ExtractedObject batch serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
metadata = Metadata(id="test", collection="col")
|
||||
batch_object_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
|
|
@ -392,7 +387,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_extracted_object_empty_batch_serialization(self):
|
||||
"""Test ExtractedObject empty batch serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
metadata = Metadata(id="test", collection="col")
|
||||
empty_batch_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
|
|
|
|||
|
|
@ -188,7 +188,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
message_type="answer",
|
||||
content="4",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
|
|
@ -210,7 +210,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
message_type="thought",
|
||||
content="I need to solve this.",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
|
|
@ -233,7 +233,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
|
||||
# Test thought message
|
||||
thought_response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
message_type="thought",
|
||||
content="Processing...",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
|
|
@ -247,7 +247,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
|
||||
# Test observation message
|
||||
observation_response = AgentResponse(
|
||||
chunk_type="observation",
|
||||
message_type="observation",
|
||||
content="Result found",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
|
|
@ -268,7 +268,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
|
||||
# Streaming format with end_of_dialog=True
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
message_type="answer",
|
||||
content="",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
|
|
|
|||
|
|
@ -418,55 +418,55 @@ def sample_streaming_agent_response():
|
|||
"""Sample streaming agent response chunks"""
|
||||
return [
|
||||
{
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": "I need to search",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": " for information",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": " about machine learning.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "action",
|
||||
"message_type": "action",
|
||||
"content": "knowledge_query",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "observation",
|
||||
"message_type": "observation",
|
||||
"content": "Machine learning is",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "observation",
|
||||
"message_type": "observation",
|
||||
"content": " a subset of AI.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "final-answer",
|
||||
"message_type": "final-answer",
|
||||
"content": "Machine learning",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "final-answer",
|
||||
"message_type": "final-answer",
|
||||
"content": " is a subset",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "final-answer",
|
||||
"message_type": "final-answer",
|
||||
"content": " of artificial intelligence.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True
|
||||
|
|
@ -494,10 +494,10 @@ def streaming_chunk_collector():
|
|||
"""Concatenate all chunk content"""
|
||||
return "".join(self.chunks)
|
||||
|
||||
def get_chunk_types(self):
|
||||
def get_message_types(self):
|
||||
"""Get list of chunk types if chunks are dicts"""
|
||||
if self.chunks and isinstance(self.chunks[0], dict):
|
||||
return [c.get("chunk_type") for c in self.chunks]
|
||||
return [c.get("message_type") for c in self.chunks]
|
||||
return []
|
||||
|
||||
def verify_streaming_protocol(self):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
|
|||
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
|
||||
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||||
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
|
|||
|
||||
# Mock prompt client
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
|
||||
prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for information about machine learning
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is machine learning?"
|
||||
}"""
|
||||
|
||||
)
|
||||
|
||||
# Mock graph RAG client
|
||||
graph_rag_client = AsyncMock()
|
||||
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
|
||||
|
||||
|
||||
# Mock text completion client
|
||||
text_completion_client = AsyncMock()
|
||||
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
|
||||
text_completion_client.question.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Machine learning involves algorithms that improve through experience."
|
||||
)
|
||||
|
||||
# Mock MCP tool client
|
||||
mcp_tool_client = AsyncMock()
|
||||
|
|
@ -147,8 +154,11 @@ Args: {
|
|||
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager returning final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have enough information to answer the question
|
||||
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
|
|||
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test ReAct cycle ending with final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide a direct answer
|
||||
Final Answer: Machine learning is a branch of artificial intelligence."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
|
|||
|
||||
for tool_name, expected_service in tool_scenarios:
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: I need to use {tool_name}
|
||||
Action: {tool_name}
|
||||
Args: {{
|
||||
"question": "test question"
|
||||
}}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -284,11 +300,14 @@ Args: {{
|
|||
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager error handling for unknown tool"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to use an unknown tool
|
||||
Action: unknown_tool
|
||||
Args: {
|
||||
"param": "value"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -308,11 +327,13 @@ Args: {
|
|||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
assert "Tool execution failed" in str(exc_info.value)
|
||||
# Act - tool errors are now caught and returned as observations
|
||||
result = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
# Assert - error captured on the action, not raised
|
||||
assert result.tool_error is not None
|
||||
assert "Tool execution failed" in result.tool_error
|
||||
assert "Error:" in result.observation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context):
|
||||
|
|
@ -321,11 +342,14 @@ Args: {
|
|||
question = "Find information about AI and summarize it"
|
||||
|
||||
# Mock multi-step reasoning
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for AI information first
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is artificial intelligence?"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, [], mock_flow_context)
|
||||
|
|
@ -372,9 +396,12 @@ Args: {
|
|||
# Format arguments as JSON
|
||||
import json
|
||||
args_json = json.dumps(test_case['arguments'], indent=4)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: Using {test_case['action']}
|
||||
Action: {test_case['action']}
|
||||
Args: {args_json}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -507,15 +534,17 @@ Args: {
|
|||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=test_case["response"]
|
||||
)
|
||||
|
||||
if test_case["error_contains"]:
|
||||
# Should raise an error
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await agent_manager.reason("test question", [], mock_flow_context)
|
||||
|
||||
assert "Failed to parse agent response" in str(exc_info.value)
|
||||
assert test_case["error_contains"] in str(exc_info.value)
|
||||
# Parse errors now return an Action with tool_error
|
||||
result = await agent_manager.reason("test question", [], mock_flow_context)
|
||||
assert isinstance(result, Action)
|
||||
assert result.name == "__parse_error__"
|
||||
assert result.tool_error is not None
|
||||
else:
|
||||
# Should succeed
|
||||
action = await agent_manager.reason("test question", [], mock_flow_context)
|
||||
|
|
@ -527,13 +556,16 @@ Args: {
|
|||
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
|
||||
"""Test edge cases in text parsing"""
|
||||
# Test response with markdown code blocks
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """```
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""```
|
||||
Thought: I need to search for information
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is AI?"
|
||||
}
|
||||
```"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -541,15 +573,18 @@ Args: {
|
|||
assert action.name == "knowledge_query"
|
||||
|
||||
# Test response with extra whitespace
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""
|
||||
|
||||
Thought: I need to think about this
|
||||
Action: knowledge_query
|
||||
Thought: I need to think about this
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "test"
|
||||
}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -560,7 +595,9 @@ Args: {
|
|||
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
|
||||
"""Test handling of multi-line thoughts and final answers"""
|
||||
# Multi-line thought
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to consider multiple factors:
|
||||
1. The user's question is complex
|
||||
2. I should search for comprehensive information
|
||||
3. This requires using the knowledge query tool
|
||||
|
|
@ -568,6 +605,7 @@ Action: knowledge_query
|
|||
Args: {
|
||||
"question": "complex query"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -575,13 +613,16 @@ Args: {
|
|||
assert "knowledge query tool" in action.thought
|
||||
|
||||
# Multi-line final answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have gathered enough information
|
||||
Final Answer: Here is a comprehensive answer:
|
||||
1. First point about the topic
|
||||
2. Second point with details
|
||||
3. Final conclusion
|
||||
|
||||
This covers all aspects of the question."""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -593,13 +634,16 @@ This covers all aspects of the question."""
|
|||
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
|
||||
"""Test JSON arguments with special characters and edge cases"""
|
||||
# Test with special characters in JSON (properly escaped)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Processing special characters
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What about \\"quotes\\" and 'apostrophes'?",
|
||||
"context": "Line 1\\nLine 2\\tTabbed",
|
||||
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -608,7 +652,9 @@ Args: {
|
|||
assert "@#$%^&*" in action.arguments["special"]
|
||||
|
||||
# Test with nested JSON
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Complex arguments
|
||||
Action: web_search
|
||||
Args: {
|
||||
"query": "test",
|
||||
|
|
@ -621,6 +667,7 @@ Args: {
|
|||
}
|
||||
}
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -632,7 +679,9 @@ Args: {
|
|||
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
|
||||
"""Test final answers that contain JSON-like content"""
|
||||
# Final answer with JSON content
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide the data in JSON format
|
||||
Final Answer: {
|
||||
"result": "success",
|
||||
"data": {
|
||||
|
|
@ -642,6 +691,7 @@ Final Answer: {
|
|||
},
|
||||
"confidence": 0.95
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -792,11 +842,14 @@ Final Answer: {
|
|||
agent = AgentManager(tools=custom_tools, additional_context="")
|
||||
|
||||
# Mock response for custom collection query
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search in the research papers
|
||||
Action: knowledge_query_custom
|
||||
Args: {
|
||||
"question": "Latest AI research?"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -10,11 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.tools import KnowledgeQueryImpl
|
||||
from trustgraph.agent.react.types import Tool, Argument
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_agent_streaming_chunks,
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
assert_chunk_types_valid,
|
||||
assert_message_types_valid,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -51,10 +52,10 @@ Args: {
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.agent_react.side_effect = agent_react_streaming
|
||||
return client
|
||||
|
|
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
|
|||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk + " ", is_final)
|
||||
return response
|
||||
return response
|
||||
return PromptResult(response_type="text", text=response)
|
||||
return PromptResult(response_type="text", text=response)
|
||||
|
||||
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.schema import (
|
|||
Error
|
||||
)
|
||||
from trustgraph.agent.react.service import Processor
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -57,7 +58,7 @@ class TestAgentStructuredQueryIntegration:
|
|||
async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
|
||||
"""Test basic agent integration with structured query tool"""
|
||||
# Arrange - Load tool configuration
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
# Create agent request
|
||||
request = AgentRequest(
|
||||
|
|
@ -65,7 +66,6 @@ class TestAgentStructuredQueryIntegration:
|
|||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -95,11 +95,14 @@ class TestAgentStructuredQueryIntegration:
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from New York using structured query
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from New York"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -115,6 +118,7 @@ Args: {
|
|||
# Mock flow parameter in agent_processor.on_request
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -142,14 +146,13 @@ Args: {
|
|||
async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
|
||||
"""Test agent handling of structured query errors"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Find data from a table that doesn't exist using structured query.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -173,11 +176,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query for a table that might not exist
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find data from a table that doesn't exist"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -192,6 +198,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -214,14 +221,13 @@ Args: {
|
|||
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
|
||||
"""Test agent using structured query in multi-step reasoning"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="First find all customers from California, then tell me how many orders they have made.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -250,11 +256,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from California first
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from California"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -269,6 +278,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -303,14 +313,13 @@ Args: {
|
|||
}
|
||||
}
|
||||
|
||||
await agent_processor.on_tools_config(tool_config_with_collection, "v1")
|
||||
await agent_processor.on_tools_config("default", tool_config_with_collection, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Query the sales data for recent transactions.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -339,11 +348,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query the sales data
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Query the sales data for recent transactions"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -358,6 +370,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -381,10 +394,10 @@ Args: {
|
|||
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
|
||||
"""Test that structured query tool arguments are properly validated"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
# Check that the tool was registered with correct arguments
|
||||
tools = agent_processor.agent.tools
|
||||
tools = agent_processor.agents["default"].tools
|
||||
assert "structured-query" in tools
|
||||
|
||||
structured_tool = tools["structured-query"]
|
||||
|
|
@ -401,14 +414,13 @@ Args: {
|
|||
async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
|
||||
"""Test that structured query results are properly formatted for agent consumption"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Get customer information and format it nicely.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -447,11 +459,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to get customer information
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Get customer information and format it nicely"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -466,6 +481,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
|
|||
|
|
@ -40,14 +40,13 @@ class TestEndToEndConfigurationFlow:
|
|||
|
||||
# Create a mock message to trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
# This should create TrustGraph with environment config
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('test_user', mock_message)
|
||||
|
||||
# Verify Cluster was created with correct hosts
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -144,13 +143,12 @@ class TestConfigurationPriorityEndToEnd:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('test_user', mock_message)
|
||||
|
||||
# Should use CLI parameters, not environment
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -201,7 +199,6 @@ class TestConfigurationPriorityEndToEnd:
|
|||
|
||||
# Mock query to trigger TrustGraph creation
|
||||
mock_query = MagicMock()
|
||||
mock_query.user = 'default_user'
|
||||
mock_query.collection = 'default_collection'
|
||||
mock_query.s = None
|
||||
mock_query.p = None
|
||||
|
|
@ -213,7 +210,7 @@ class TestConfigurationPriorityEndToEnd:
|
|||
mock_tg_instance.get_all.return_value = []
|
||||
processor.tg = mock_tg_instance
|
||||
|
||||
await processor.query_triples(mock_query)
|
||||
await processor.query_triples('default_user', mock_query)
|
||||
|
||||
# Should use defaults
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -244,13 +241,12 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'legacy_user'
|
||||
mock_message.metadata.collection = 'legacy_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('legacy_user', mock_message)
|
||||
|
||||
# Should use defaults since old parameters are not recognized
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -302,13 +298,12 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'precedence_user'
|
||||
mock_message.metadata.collection = 'precedence_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('precedence_user', mock_message)
|
||||
|
||||
# Should use new parameters, not old ones
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -354,13 +349,12 @@ class TestMultipleHostsHandling:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'single_user'
|
||||
mock_message.metadata.collection = 'single_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('single_user', mock_message)
|
||||
|
||||
# Single host should be converted to list
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Create test message
|
||||
storage_message = Triples(
|
||||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
metadata=Metadata(collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/person1"),
|
||||
|
|
@ -178,7 +178,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Store test data for querying
|
||||
query_test_message = Triples(
|
||||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
metadata=Metadata(collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/alice"),
|
||||
|
|
@ -212,7 +212,6 @@ class TestCassandraIntegration:
|
|||
p=None, # None for wildcard
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
collection="testcol"
|
||||
)
|
||||
s_results = await query_processor.query_triples(s_query)
|
||||
|
|
@ -232,7 +231,6 @@ class TestCassandraIntegration:
|
|||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
collection="testcol"
|
||||
)
|
||||
p_results = await query_processor.query_triples(p_query)
|
||||
|
|
@ -259,7 +257,7 @@ class TestCassandraIntegration:
|
|||
# Create multiple coroutines for concurrent storage
|
||||
async def store_person_data(person_id, name, age, department):
|
||||
message = Triples(
|
||||
metadata=Metadata(user="concurrent_test", collection="people"),
|
||||
metadata=Metadata(collection="people"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
|
||||
|
|
@ -329,7 +327,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Create a knowledge graph about a company
|
||||
company_graph = Triples(
|
||||
metadata=Metadata(user="integration_test", collection="company"),
|
||||
metadata=Metadata(collection="company"),
|
||||
triples=[
|
||||
# People and their types
|
||||
Triple(
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# Sample chunk content for testing - maps chunk_id to content
|
||||
|
|
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
|
|||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
client = AsyncMock()
|
||||
client.document_prompt.return_value = (
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
)
|
||||
)
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -93,7 +99,6 @@ class TestDocumentRagIntegration:
|
|||
# Act
|
||||
result = await document_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
|
@ -104,7 +109,6 @@ class TestDocumentRagIntegration:
|
|||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||
limit=doc_limit,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
|
|
@ -119,6 +123,7 @@ class TestDocumentRagIntegration:
|
|||
)
|
||||
|
||||
# Verify final response
|
||||
result, usage = result
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "machine learning" in result.lower()
|
||||
|
|
@ -131,7 +136,11 @@ class TestDocumentRagIntegration:
|
|||
"""Test DocumentRAG behavior when no documents are retrieved"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="I couldn't find any relevant documents for your query."
|
||||
)
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -152,7 +161,8 @@ class TestDocumentRagIntegration:
|
|||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "I couldn't find any relevant documents for your query."
|
||||
result_text, usage = result
|
||||
assert result_text == "I couldn't find any relevant documents for your query."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
|
|
@ -266,14 +276,12 @@ class TestDocumentRagIntegration:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
f"query from {user} in {collection}",
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -341,6 +349,5 @@ class TestDocumentRagIntegration:
|
|||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == "trustgraph"
|
||||
assert call_args.kwargs['collection'] == "default"
|
||||
assert call_args.kwargs['limit'] == 20
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
|
|
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -104,7 +107,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10,
|
||||
streaming=True,
|
||||
|
|
@ -119,11 +121,12 @@ class TestDocumentRagStreaming:
|
|||
collector.verify_streaming_protocol()
|
||||
|
||||
# Verify full response matches concatenated chunks
|
||||
result_text, usage = result
|
||||
full_from_chunks = collector.get_full_text()
|
||||
assert result == full_from_chunks
|
||||
assert result_text == full_from_chunks
|
||||
|
||||
# Verify content is reasonable
|
||||
assert len(result) > 0
|
||||
assert len(result_text) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
|
||||
|
|
@ -137,7 +140,6 @@ class TestDocumentRagStreaming:
|
|||
# Act - Non-streaming
|
||||
non_streaming_result = await document_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
streaming=False
|
||||
|
|
@ -151,7 +153,6 @@ class TestDocumentRagStreaming:
|
|||
|
||||
streaming_result = await document_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
streaming=True,
|
||||
|
|
@ -159,9 +160,11 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_result == non_streaming_result
|
||||
non_streaming_text, _ = non_streaming_result
|
||||
streaming_text, _ = streaming_result
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_result
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
|
||||
|
|
@ -172,7 +175,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5,
|
||||
streaming=True,
|
||||
|
|
@ -180,8 +182,9 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
result_text, usage = result
|
||||
assert callback.call_count > 0
|
||||
assert result is not None
|
||||
assert result_text is not None
|
||||
|
||||
# Verify all callback invocations had string arguments
|
||||
for call in callback.call_args_list:
|
||||
|
|
@ -193,7 +196,6 @@ class TestDocumentRagStreaming:
|
|||
# Arrange & Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5,
|
||||
streaming=True,
|
||||
|
|
@ -202,7 +204,8 @@ class TestDocumentRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
result_text, usage = result
|
||||
assert isinstance(result_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
|
||||
|
|
@ -215,7 +218,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10,
|
||||
streaming=True,
|
||||
|
|
@ -223,7 +225,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should still produce streamed response
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -238,7 +241,6 @@ class TestDocumentRagStreaming:
|
|||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5,
|
||||
streaming=True,
|
||||
|
|
@ -263,7 +265,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=limit,
|
||||
streaming=True,
|
||||
|
|
@ -271,7 +272,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
# Verify doc_limit was passed correctly
|
||||
|
|
@ -290,7 +292,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=10,
|
||||
streaming=True,
|
||||
|
|
@ -299,5 +300,4 @@ class TestDocumentRagStreaming:
|
|||
|
||||
# Assert - Verify user/collection were passed to document embeddings client
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
|
|||
# 4. kg-synthesis returns the final answer
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return "" # No edges scored
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return "" # No reasoning
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
|
@ -142,7 +146,6 @@ class TestGraphRagIntegration:
|
|||
# Act
|
||||
response = await graph_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
|
|
@ -159,7 +162,6 @@ class TestGraphRagIntegration:
|
|||
call_args = mock_graph_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert call_args.kwargs['limit'] == entity_limit
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
# 3. Should query triples to build knowledge subgraph
|
||||
|
|
@ -169,6 +171,7 @@ class TestGraphRagIntegration:
|
|||
assert mock_prompt_client.prompt.call_count == 4
|
||||
|
||||
# Verify final response
|
||||
response, usage = response
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
assert "machine learning" in response.lower()
|
||||
|
|
@ -199,7 +202,6 @@ class TestGraphRagIntegration:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=config["entity_limit"],
|
||||
triple_limit=config["triple_limit"]
|
||||
|
|
@ -219,7 +221,6 @@ class TestGraphRagIntegration:
|
|||
with pytest.raises(Exception) as exc_info:
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -242,7 +243,6 @@ class TestGraphRagIntegration:
|
|||
# Act
|
||||
response = await graph_rag.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
|
@ -262,7 +262,6 @@ class TestGraphRagIntegration:
|
|||
# First query
|
||||
await graph_rag.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -272,7 +271,6 @@ class TestGraphRagIntegration:
|
|||
# Second identical query
|
||||
await graph_rag.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -284,26 +282,27 @@ class TestGraphRagIntegration:
|
|||
assert second_call_count >= 0 # Should complete without errors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client):
|
||||
"""Test that different users/collections are properly isolated"""
|
||||
async def test_graph_rag_multi_collection_isolation(self, graph_rag, mock_graph_embeddings_client):
|
||||
"""Test that different collections propagate through to the embeddings query.
|
||||
|
||||
Workspace isolation is enforced by flow.workspace at the service
|
||||
boundary — not by parameters on GraphRag.query — so this test
|
||||
verifies collection routing only.
|
||||
"""
|
||||
# Arrange
|
||||
query = "test query"
|
||||
user1, collection1 = "user1", "collection1"
|
||||
user2, collection2 = "user2", "collection2"
|
||||
collection1 = "collection1"
|
||||
collection2 = "collection2"
|
||||
|
||||
# Act
|
||||
await graph_rag.query(query=query, user=user1, collection=collection1)
|
||||
await graph_rag.query(query=query, user=user2, collection=collection2)
|
||||
await graph_rag.query(query=query, collection=collection1)
|
||||
await graph_rag.query(query=query, collection=collection2)
|
||||
|
||||
# Assert - Both users should have separate queries
|
||||
# Assert - Each call propagated its collection
|
||||
assert mock_graph_embeddings_client.query.call_count == 2
|
||||
|
||||
# Verify first call
|
||||
first_call = mock_graph_embeddings_client.query.call_args_list[0]
|
||||
assert first_call.kwargs['user'] == user1
|
||||
assert first_call.kwargs['collection'] == collection1
|
||||
|
||||
# Verify second call
|
||||
second_call = mock_graph_embeddings_client.query.call_args_list[1]
|
||||
assert second_call.kwargs['user'] == user2
|
||||
assert second_call.kwargs['collection'] == collection2
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_rag_streaming_chunks,
|
||||
|
|
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
|
|||
|
||||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return '{"id": "abc12345", "score": 0.9}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
|
|
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
return full_text
|
||||
return ""
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -115,7 +116,6 @@ class TestGraphRagStreaming:
|
|||
# Act - query() returns response, provenance via callback
|
||||
response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collector.collect,
|
||||
|
|
@ -123,6 +123,7 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
response, usage = response
|
||||
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
|
||||
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
|
||||
|
||||
|
|
@ -152,7 +153,6 @@ class TestGraphRagStreaming:
|
|||
# Act - Non-streaming
|
||||
non_streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=False
|
||||
)
|
||||
|
|
@ -165,16 +165,17 @@ class TestGraphRagStreaming:
|
|||
|
||||
streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_response == non_streaming_response
|
||||
non_streaming_text, _ = non_streaming_response
|
||||
streaming_text, _ = streaming_response
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_response
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
|
||||
|
|
@ -185,7 +186,6 @@ class TestGraphRagStreaming:
|
|||
# Act
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -205,7 +205,6 @@ class TestGraphRagStreaming:
|
|||
# Arrange & Act
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=None # No callback provided
|
||||
|
|
@ -213,7 +212,8 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
response_text, usage = response
|
||||
assert isinstance(response_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
|
||||
|
|
@ -226,7 +226,6 @@ class TestGraphRagStreaming:
|
|||
# Act
|
||||
response = await graph_rag_streaming.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -248,7 +247,6 @@ class TestGraphRagStreaming:
|
|||
with pytest.raises(Exception) as exc_info:
|
||||
await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -268,7 +266,6 @@ class TestGraphRagStreaming:
|
|||
# Act
|
||||
await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
|
|
|
|||
|
|
@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend):
|
|||
triples_obj = Triples(
|
||||
metadata=Metadata(
|
||||
id=f"export-msg-{i}",
|
||||
user=msg_data["metadata"]["user"],
|
||||
collection=msg_data["metadata"]["collection"],
|
||||
),
|
||||
triples=to_subgraph(msg_data["triples"]),
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
|
|||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
|
||||
# Mock prompt client for definitions extraction
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.extract_definitions.return_value = [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
|
||||
prompt_client.extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock prompt client for relationships extraction
|
||||
prompt_client.extract_relationships.return_value = [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
prompt_client.extract_relationships.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock producers for output streams
|
||||
triples_producer = AsyncMock()
|
||||
|
|
@ -90,7 +97,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
return Chunk(
|
||||
metadata=Metadata(
|
||||
id="doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
|
||||
|
|
@ -240,7 +246,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -298,7 +303,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -368,7 +372,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
sample_triples = Triples(
|
||||
metadata=Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
triples=[
|
||||
|
|
@ -383,11 +386,14 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_triples
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "test_workspace"
|
||||
|
||||
# Act
|
||||
await processor.on_triples(mock_msg, None, None)
|
||||
await processor.on_triples(mock_msg, None, mock_flow)
|
||||
|
||||
# Assert
|
||||
mock_cassandra_store.add_triples.assert_called_once_with(sample_triples)
|
||||
mock_cassandra_store.add_triples.assert_called_once_with("test_workspace", sample_triples)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store):
|
||||
|
|
@ -400,7 +406,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
sample_embeddings = GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
entities=[
|
||||
|
|
@ -414,11 +419,14 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_embeddings
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "test_workspace"
|
||||
|
||||
# Act
|
||||
await processor.on_graph_embeddings(mock_msg, None, None)
|
||||
await processor.on_graph_embeddings(mock_msg, None, mock_flow)
|
||||
|
||||
# Assert
|
||||
mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings)
|
||||
mock_cassandra_store.add_graph_embeddings.assert_called_once_with("test_workspace", sample_embeddings)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor,
|
||||
|
|
@ -489,7 +497,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of empty extraction results"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = []
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -510,7 +521,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of invalid extraction response format"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="invalid format"
|
||||
) # Should be jsonl with objects list
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -528,16 +542,19 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
|
||||
"""Test entity filtering and validation in extraction"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = [
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
)
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=Metadata(id="test", user="user", collection="collection"),
|
||||
metadata=Metadata(id="test", collection="collection"),
|
||||
chunk=b"Test chunk"
|
||||
)
|
||||
|
||||
|
|
@ -564,7 +581,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
large_chunk_batch = [
|
||||
Chunk(
|
||||
metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"),
|
||||
metadata=Metadata(id=f"doc-{i}", collection="collection"),
|
||||
chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
|
||||
)
|
||||
for i in range(100) # Large batch
|
||||
|
|
@ -601,7 +618,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
original_metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -630,9 +646,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
entity_contexts_call = entity_contexts_producer.send.call_args[0][0]
|
||||
|
||||
assert triples_call.metadata.id == "test-doc-123"
|
||||
assert triples_call.metadata.user == "test_user"
|
||||
assert triples_call.metadata.collection == "test_collection"
|
||||
|
||||
assert entity_contexts_call.metadata.id == "test-doc-123"
|
||||
assert entity_contexts_call.metadata.user == "test_user"
|
||||
assert entity_contexts_call.metadata.collection == "test_collection"
|
||||
|
|
@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Set up schemas
|
||||
proc.schemas = sample_schemas
|
||||
proc.schemas = {"default": dict(sample_schemas)}
|
||||
|
||||
# Mock the client method
|
||||
proc.client = MagicMock()
|
||||
|
|
@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration:
|
|||
}
|
||||
|
||||
# Act - Update configuration
|
||||
await integration_processor.on_schema_config(new_schema_config, "v2")
|
||||
await integration_processor.on_schema_config("default", new_schema_config, "v2")
|
||||
|
||||
# Arrange - Test query using new schema
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
|
|
@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration:
|
|||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
||||
# Assert
|
||||
assert "inventory" in integration_processor.schemas
|
||||
assert "inventory" in integration_processor.schemas["default"]
|
||||
response_call = flow_response.send.call_args
|
||||
response = response_call[0][0]
|
||||
assert response.detected_schemas == ["inventory"]
|
||||
|
|
@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration:
|
|||
graphql_generation_template="custom-graphql-generator"
|
||||
)
|
||||
|
||||
custom_processor.schemas = sample_schemas
|
||||
custom_processor.schemas = {"default": dict(sample_schemas)}
|
||||
custom_processor.client = MagicMock()
|
||||
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
|
|
@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration:
|
|||
] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
|
||||
)
|
||||
|
||||
integration_processor.schemas.update(large_schema_set)
|
||||
integration_processor.schemas["default"].update(large_schema_set)
|
||||
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
question="Show me data from table_05 and table_12",
|
||||
|
|
@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration:
|
|||
msg.properties.return_value = {"id": f"concurrent-test-{i}"}
|
||||
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.schema import (
|
|||
Chunk, ExtractedObject, Metadata, RowSchema, Field,
|
||||
PromptRequest, PromptResponse
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
|
|||
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
|
||||
if schema_name == "customer_records":
|
||||
if "john" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "jane" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
elif schema_name == "product_catalog":
|
||||
if "laptop" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "book" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return []
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
prompt_client.extract_objects.side_effect = mock_extract_objects
|
||||
|
||||
|
|
@ -172,6 +185,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
return context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -184,20 +198,21 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Assert
|
||||
assert len(processor.schemas) == 2
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" in processor.schemas
|
||||
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 2
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" in ws_schemas
|
||||
|
||||
# Verify customer schema
|
||||
customer_schema = processor.schemas["customer_records"]
|
||||
customer_schema = ws_schemas["customer_records"]
|
||||
assert customer_schema.name == "customer_records"
|
||||
assert len(customer_schema.fields) == 4
|
||||
|
||||
|
||||
# Verify product schema
|
||||
product_schema = processor.schemas["product_catalog"]
|
||||
product_schema = ws_schemas["product_catalog"]
|
||||
assert product_schema.name == "product_catalog"
|
||||
assert len(product_schema.fields) == 4
|
||||
|
||||
|
|
@ -224,12 +239,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create realistic customer data chunk
|
||||
metadata = Metadata(
|
||||
id="customer-doc-001",
|
||||
user="integration_test",
|
||||
collection="test_documents",
|
||||
)
|
||||
|
||||
|
|
@ -291,12 +305,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create realistic product data chunk
|
||||
metadata = Metadata(
|
||||
id="product-doc-001",
|
||||
user="integration_test",
|
||||
collection="test_documents",
|
||||
)
|
||||
|
||||
|
|
@ -355,7 +368,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create multiple test chunks
|
||||
chunks_data = [
|
||||
|
|
@ -369,7 +382,6 @@ class TestObjectExtractionServiceIntegration:
|
|||
for chunk_id, text in chunks_data:
|
||||
metadata = Metadata(
|
||||
id=chunk_id,
|
||||
user="concurrent_test",
|
||||
collection="test_collection",
|
||||
)
|
||||
chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8'))
|
||||
|
|
@ -418,19 +430,21 @@ class TestObjectExtractionServiceIntegration:
|
|||
"customer_records": integration_config["schema"]["customer_records"]
|
||||
}
|
||||
}
|
||||
await processor.on_schema_config(initial_config, version=1)
|
||||
|
||||
assert len(processor.schemas) == 1
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" not in processor.schemas
|
||||
|
||||
await processor.on_schema_config("default", initial_config, version=1)
|
||||
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 1
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" not in ws_schemas
|
||||
|
||||
# Act - Reload with full configuration
|
||||
await processor.on_schema_config(integration_config, version=2)
|
||||
|
||||
await processor.on_schema_config("default", integration_config, version=2)
|
||||
|
||||
# Assert
|
||||
assert len(processor.schemas) == 2
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" in processor.schemas
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 2
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" in ws_schemas
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_resilience_integration(self, integration_config):
|
||||
|
|
@ -461,13 +475,14 @@ class TestObjectExtractionServiceIntegration:
|
|||
return AsyncMock()
|
||||
|
||||
failing_flow.side_effect = failing_context_router
|
||||
failing_flow.workspace = "default"
|
||||
processor.flow = failing_flow
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create test chunk
|
||||
metadata = Metadata(id="error-test", user="test", collection="test")
|
||||
metadata = Metadata(id="error-test", collection="test")
|
||||
chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
|
|
@ -497,12 +512,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create chunk with rich metadata
|
||||
original_metadata = Metadata(
|
||||
id="metadata-test-chunk",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -531,6 +545,5 @@ class TestObjectExtractionServiceIntegration:
|
|||
assert extracted_obj is not None
|
||||
|
||||
# Verify metadata propagation
|
||||
assert extracted_obj.metadata.user == "test_user"
|
||||
assert extracted_obj.metadata.collection == "test_collection"
|
||||
assert "metadata-test-chunk" in extracted_obj.metadata.id # Should include source reference
|
||||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.prompt.template.service import Processor
|
||||
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
|
||||
from trustgraph.base.text_completion_client import TextCompletionResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
|
|
@ -27,34 +28,52 @@ class TestPromptStreaming:
|
|||
# Mock text completion client with streaming
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def streaming_request(request, recipient=None, timeout=600):
|
||||
"""Simulate streaming text completion"""
|
||||
if request.streaming and recipient:
|
||||
# Simulate streaming chunks
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
# Streaming chunks to send
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=is_final
|
||||
)
|
||||
final = await recipient(response)
|
||||
if final:
|
||||
break
|
||||
|
||||
# Final empty chunk
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
|
||||
"""Simulate streaming text completion via text_completion_stream"""
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
end_of_stream=False
|
||||
)
|
||||
await handler(response)
|
||||
|
||||
text_completion_client.request = streaming_request
|
||||
# Send final empty chunk with end_of_stream
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
return TextCompletionResult(
|
||||
text=None,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, timeout=600):
|
||||
"""Simulate non-streaming text completion"""
|
||||
full_text = "Machine learning is a field of artificial intelligence."
|
||||
return TextCompletionResult(
|
||||
text=full_text,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=streaming_text_completion_stream
|
||||
)
|
||||
text_completion_client.text_completion = AsyncMock(
|
||||
side_effect=non_streaming_text_completion
|
||||
)
|
||||
|
||||
# Mock response producer
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -68,6 +87,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
return context
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -90,7 +110,7 @@ class TestPromptStreaming:
|
|||
def prompt_processor_streaming(self, mock_prompt_manager):
|
||||
"""Create Prompt processor with streaming support"""
|
||||
processor = MagicMock()
|
||||
processor.manager = mock_prompt_manager
|
||||
processor.managers = {"default": mock_prompt_manager}
|
||||
processor.config_key = "prompt"
|
||||
|
||||
# Bind the actual on_request method
|
||||
|
|
@ -156,14 +176,6 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Mock non-streaming text completion
|
||||
text_completion_client = mock_flow_context_streaming("text-completion-request")
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, streaming=False):
|
||||
return "AI is the simulation of human intelligence in machines."
|
||||
|
||||
text_completion_client.text_completion = non_streaming_text_completion
|
||||
|
||||
# Act
|
||||
await prompt_processor_streaming.on_request(
|
||||
message, consumer, mock_flow_context_streaming
|
||||
|
|
@ -218,17 +230,12 @@ class TestPromptStreaming:
|
|||
# Mock text completion client that raises an error
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def failing_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
# Send error response with proper Error schema
|
||||
error_response = TextCompletionResponse(
|
||||
response="",
|
||||
error=Error(message="Text completion error", type="processing_error"),
|
||||
end_of_stream=True
|
||||
)
|
||||
await recipient(error_response)
|
||||
async def failing_stream(system, prompt, handler, timeout=600):
|
||||
raise RuntimeError("Text completion error")
|
||||
|
||||
text_completion_client.request = failing_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=failing_stream
|
||||
)
|
||||
|
||||
# Mock response producer to capture error response
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -242,6 +249,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
|
||||
request = PromptRequest(
|
||||
id="test_prompt",
|
||||
|
|
@ -255,22 +263,15 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Act - The service catches errors and sends error responses, doesn't raise
|
||||
# Act - The service catches errors and sends an error PromptResponse
|
||||
await prompt_processor_streaming.on_request(message, consumer, context)
|
||||
|
||||
# Assert - Verify error response was sent
|
||||
assert response_producer.send.call_count > 0
|
||||
|
||||
# Check that at least one response contains an error
|
||||
error_sent = False
|
||||
for call in response_producer.send.call_args_list:
|
||||
response = call.args[0]
|
||||
if hasattr(response, 'error') and response.error:
|
||||
error_sent = True
|
||||
assert "Text completion error" in response.error.message
|
||||
break
|
||||
|
||||
assert error_sent, "Expected error response to be sent"
|
||||
# Assert - error response was sent
|
||||
calls = response_producer.send.call_args_list
|
||||
assert len(calls) > 0
|
||||
error_response = calls[-1].args[0]
|
||||
assert error_response.error is not None
|
||||
assert "Text completion error" in error_response.error.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
|
||||
|
|
@ -315,21 +316,22 @@ class TestPromptStreaming:
|
|||
# Mock text completion that sends empty chunks
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def empty_streaming_request(request, recipient=None, timeout=600):
|
||||
if request.streaming and recipient:
|
||||
# Send empty chunk followed by final marker
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
async def empty_streaming(system, prompt, handler, timeout=600):
|
||||
# Send empty chunk followed by final marker
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
text_completion_client.request = empty_streaming_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=empty_streaming
|
||||
)
|
||||
response_producer = AsyncMock()
|
||||
|
||||
def context_router(service_name):
|
||||
|
|
@ -341,6 +343,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
|
||||
request = PromptRequest(
|
||||
id="test_prompt",
|
||||
|
|
@ -401,4 +404,4 @@ class TestPromptStreaming:
|
|||
|
||||
# Verify chunks concatenate to expected result
|
||||
full_text = "".join(chunk_texts)
|
||||
assert full_text == "Machine learning is a field of artificial intelligence"
|
||||
assert full_text == "Machine learning is a field of artificial intelligence."
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
|
|||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
class TestGraphRagStreamingProtocol:
|
||||
|
|
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
|
|||
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Edge selection returns empty (no edges selected)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
|
|
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
|
|||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "The answer is here."
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="The answer is here.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -84,7 +84,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -108,7 +107,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -137,7 +135,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -162,7 +159,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -188,7 +184,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -237,11 +232,13 @@ class TestDocumentRagStreamingProtocol:
|
|||
await chunk_callback("Document", False)
|
||||
await chunk_callback(" summary", False)
|
||||
await chunk_callback(".", True) # Non-empty final chunk
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "Document summary."
|
||||
return PromptResult(response_type="text", text="Document summary.")
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -265,7 +262,6 @@ class TestDocumentRagStreamingProtocol:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -288,7 +284,6 @@ class TestDocumentRagStreamingProtocol:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -312,7 +307,6 @@ class TestDocumentRagStreamingProtocol:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -334,17 +328,17 @@ class TestStreamingProtocolEdgeCases:
|
|||
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "textmore"
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="textmore")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
|
|
|
|||
|
|
@ -14,10 +14,35 @@ from trustgraph.storage.rows.cassandra.write import Processor
|
|||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
|
||||
|
||||
class _MockFlowDefault:
|
||||
"""Mock Flow with default workspace for testing."""
|
||||
workspace = "default"
|
||||
name = "default"
|
||||
id = "test-processor"
|
||||
|
||||
|
||||
mock_flow_default = _MockFlowDefault()
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestRowsCassandraIntegration:
|
||||
"""Integration tests for Cassandra row storage with unified table"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_async_execute(self):
|
||||
"""Route async_execute through session.execute so the mock's
|
||||
side_effect handles all CQL (DDL and DML) uniformly and every
|
||||
call lands in mock_session.execute.call_args_list."""
|
||||
async def _fake(session, query, params=None):
|
||||
session.execute(query, params)
|
||||
return []
|
||||
with patch(
|
||||
'trustgraph.storage.rows.cassandra.write.async_execute',
|
||||
new=_fake,
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cassandra_session(self):
|
||||
"""Mock Cassandra session for integration tests"""
|
||||
|
|
@ -111,14 +136,13 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert "customer_records" in processor.schemas
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
assert "customer_records" in processor.schemas["default"]
|
||||
|
||||
# Step 2: Process an ExtractedObject
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="doc-001",
|
||||
user="test_user",
|
||||
collection="import_2024",
|
||||
),
|
||||
schema_name="customer_records",
|
||||
|
|
@ -135,7 +159,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify Cassandra interactions
|
||||
assert mock_cluster.connect.called
|
||||
|
|
@ -144,7 +168,7 @@ class TestRowsCassandraIntegration:
|
|||
keyspace_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE KEYSPACE" in str(call)]
|
||||
assert len(keyspace_calls) == 1
|
||||
assert "test_user" in str(keyspace_calls[0])
|
||||
assert "default" in str(keyspace_calls[0])
|
||||
|
||||
# Verify unified table creation (rows table, not per-schema table)
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -195,12 +219,12 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert len(processor.schemas) == 2
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
assert len(processor.schemas["default"]) == 2
|
||||
|
||||
# Process objects for different schemas
|
||||
product_obj = ExtractedObject(
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog"),
|
||||
metadata=Metadata(id="p1", collection="catalog"),
|
||||
schema_name="products",
|
||||
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -208,7 +232,7 @@ class TestRowsCassandraIntegration:
|
|||
)
|
||||
|
||||
order_obj = ExtractedObject(
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales"),
|
||||
metadata=Metadata(id="o1", collection="sales"),
|
||||
schema_name="orders",
|
||||
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
|
||||
confidence=0.85,
|
||||
|
|
@ -219,7 +243,7 @@ class TestRowsCassandraIntegration:
|
|||
for obj in [product_obj, order_obj]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = obj
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# All data goes into the same unified rows table
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -242,18 +266,20 @@ class TestRowsCassandraIntegration:
|
|||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Schema with multiple indexed fields
|
||||
processor.schemas["indexed_data"] = RowSchema(
|
||||
name="indexed_data",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True),
|
||||
Field(name="status", type="string", size=50, indexed=True),
|
||||
Field(name="description", type="string", size=200) # Not indexed
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"indexed_data": RowSchema(
|
||||
name="indexed_data",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True),
|
||||
Field(name="status", type="string", size=50, indexed=True),
|
||||
Field(name="description", type="string", size=200) # Not indexed
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
metadata=Metadata(id="t1", collection="test"),
|
||||
schema_name="indexed_data",
|
||||
values=[{
|
||||
"id": "123",
|
||||
|
|
@ -268,7 +294,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 data inserts (one per indexed field: id, category, status)
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -328,13 +354,12 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
|
||||
# Process batch object with multiple values
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_import",
|
||||
),
|
||||
schema_name="batch_customers",
|
||||
|
|
@ -362,7 +387,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify unified table creation
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -382,14 +407,16 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["empty_test"] = RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"empty_test": RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
}
|
||||
|
||||
# Process empty batch object
|
||||
empty_obj = ExtractedObject(
|
||||
metadata=Metadata(id="empty-1", user="test", collection="empty"),
|
||||
metadata=Metadata(id="empty-1", collection="empty"),
|
||||
schema_name="empty_test",
|
||||
values=[], # Empty batch
|
||||
confidence=1.0,
|
||||
|
|
@ -399,7 +426,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = empty_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should not create any data insert statements for empty batch
|
||||
# (partition registration may still happen)
|
||||
|
|
@ -414,17 +441,19 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["map_test"] = RowSchema(
|
||||
name="map_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="count", type="integer", size=0)
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"map_test": RowSchema(
|
||||
name="map_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="count", type="integer", size=0)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
metadata=Metadata(id="t1", collection="test"),
|
||||
schema_name="map_test",
|
||||
values=[{"id": "123", "name": "Test Item", "count": "42"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -434,7 +463,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify insert uses map for data
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -459,16 +488,18 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["partition_test"] = RowSchema(
|
||||
name="partition_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True)
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"partition_test": RowSchema(
|
||||
name="partition_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="my_collection"),
|
||||
metadata=Metadata(id="t1", collection="my_collection"),
|
||||
schema_name="partition_test",
|
||||
values=[{"id": "123", "category": "test"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -478,7 +509,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify partition registration
|
||||
partition_inserts = [call for call in mock_session.execute.call_args_list
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
|
||||
"""Test schema configuration loading and GraphQL schema generation"""
|
||||
# Load schema configuration
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Verify schemas were loaded
|
||||
assert len(processor.schemas) == 2
|
||||
|
|
@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
|
||||
"""Test Cassandra connection and dynamic table creation"""
|
||||
# Load schema configuration
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Connect to Cassandra
|
||||
processor.connect_cassandra()
|
||||
|
|
@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
|
||||
"""Test inserting data and querying via GraphQL"""
|
||||
# Load schema and connect
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Setup test data
|
||||
|
|
@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_graphql_query_with_filters(self, processor, sample_schema_config):
|
||||
"""Test GraphQL queries with filtering on indexed fields"""
|
||||
# Setup (reuse previous setup)
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
keyspace = "test_user"
|
||||
|
|
@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_graphql_error_handling(self, processor, sample_schema_config):
|
||||
"""Test GraphQL error handling for invalid queries"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Test invalid field query
|
||||
invalid_query = '''
|
||||
|
|
@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_message_processing_integration(self, processor, sample_schema_config):
|
||||
"""Test full message processing workflow"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Create mock message
|
||||
|
|
@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_concurrent_queries(self, processor, sample_schema_config):
|
||||
"""Test handling multiple concurrent GraphQL queries"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Create multiple query tasks
|
||||
|
|
@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(initial_config, version=1)
|
||||
await processor.on_schema_config("default", initial_config, version=1)
|
||||
assert len(processor.schemas) == 1
|
||||
assert "simple" in processor.schemas
|
||||
|
||||
|
|
@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(updated_config, version=2)
|
||||
await processor.on_schema_config("default", updated_config, version=2)
|
||||
|
||||
# Verify updated schemas
|
||||
assert len(processor.schemas) == 2
|
||||
|
|
@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_large_result_set_handling(self, processor, sample_schema_config):
|
||||
"""Test handling of large query result sets"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
keyspace = "large_test_user"
|
||||
|
|
@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
await processor.on_schema_config("default", schema_config, version=1)
|
||||
|
||||
# Measure query execution time
|
||||
start_time = time.time()
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ class TestStructuredQueryServiceIntegration:
|
|||
# Arrange - Create realistic query request
|
||||
request = StructuredQueryRequest(
|
||||
question="Show me all customers from California who have made purchases over $500",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
|
||||
|
|
@ -126,7 +125,6 @@ class TestStructuredQueryServiceIntegration:
|
|||
assert "orders" in objects_call_args.query
|
||||
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
|
||||
assert objects_call_args.variables["state"] == "California"
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
assert objects_call_args.collection == "default"
|
||||
|
||||
# Verify response
|
||||
|
|
|
|||
|
|
@ -1,17 +1,13 @@
|
|||
[pytest]
|
||||
testpaths = tests
|
||||
python_paths = .
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts =
|
||||
addopts =
|
||||
-v
|
||||
--tb=short
|
||||
--strict-markers
|
||||
--disable-warnings
|
||||
--cov=trustgraph
|
||||
--cov-report=html
|
||||
--cov-report=term-missing
|
||||
# --disable-warnings
|
||||
# --cov-fail-under=80
|
||||
asyncio_mode = auto
|
||||
markers =
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to call think and observe callbacks
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = AgentRequest(
|
||||
question="What is 2 + 2?",
|
||||
user="trustgraph",
|
||||
streaming=False # Non-streaming mode
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
|
@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
@ -78,10 +81,10 @@ class TestAgentServiceNonStreaming:
|
|||
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
r for r in sent_responses if r.message_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
r for r in sent_responses if r.message_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session, iteration, observation, and final
|
||||
|
|
@ -93,7 +96,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Check thought message
|
||||
thought_response = content_responses[0]
|
||||
assert isinstance(thought_response, AgentResponse)
|
||||
assert thought_response.chunk_type == "thought"
|
||||
assert thought_response.message_type == "thought"
|
||||
assert thought_response.content == "I need to solve this."
|
||||
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
|
||||
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
|
||||
|
|
@ -101,7 +104,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Check observation message
|
||||
observation_response = content_responses[1]
|
||||
assert isinstance(observation_response, AgentResponse)
|
||||
assert observation_response.chunk_type == "observation"
|
||||
assert observation_response.message_type == "observation"
|
||||
assert observation_response.content == "The answer is 4."
|
||||
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
|
||||
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
|
||||
|
|
@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to return Final directly
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = AgentRequest(
|
||||
question="What is 2 + 2?",
|
||||
user="trustgraph",
|
||||
streaming=False # Non-streaming mode
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
|
@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
@ -168,10 +174,10 @@ class TestAgentServiceNonStreaming:
|
|||
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
r for r in sent_responses if r.message_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
r for r in sent_responses if r.message_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session and final
|
||||
|
|
@ -183,7 +189,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Check final answer message
|
||||
answer_response = content_responses[0]
|
||||
assert isinstance(answer_response, AgentResponse)
|
||||
assert answer_response.chunk_type == "answer"
|
||||
assert answer_response.message_type == "answer"
|
||||
assert answer_response.content == "4"
|
||||
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
|
||||
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"
|
||||
|
|
|
|||
|
|
@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep
|
|||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(question="Test question", user="testuser",
|
||||
def _make_request(question="Test question",
|
||||
collection="default", streaming=False,
|
||||
session_id="parent-session", task_type="research",
|
||||
framing="test framing", conversation_id="conv-1"):
|
||||
return AgentRequest(
|
||||
question=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=streaming,
|
||||
session_id=session_id,
|
||||
|
|
@ -127,7 +126,6 @@ class TestBuildSynthesisRequest:
|
|||
req = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
|
|
@ -148,7 +146,7 @@ class TestBuildSynthesisRequest:
|
|||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
req = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# Last history step should be the synthesis step
|
||||
|
|
@ -168,7 +166,7 @@ class TestBuildSynthesisRequest:
|
|||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
|
||||
agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# Entry should be removed
|
||||
|
|
@ -178,7 +176,7 @@ class TestBuildSynthesisRequest:
|
|||
agg = Aggregator()
|
||||
with pytest.raises(RuntimeError, match="No results"):
|
||||
agg.build_synthesis_request(
|
||||
"unknown", "question", "user", "default",
|
||||
"unknown", "question", "default",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class TestThinkCallbackMessageId:
|
|||
|
||||
assert len(responses) == 1
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "thought"
|
||||
assert responses[0].message_type == "thought"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_think_has_message_id(self, pattern):
|
||||
|
|
@ -58,7 +58,7 @@ class TestObserveCallbackMessageId:
|
|||
await observe("result", is_final=True)
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "observation"
|
||||
assert responses[0].message_type == "observation"
|
||||
|
||||
|
||||
class TestAnswerCallbackMessageId:
|
||||
|
|
@ -74,7 +74,7 @@ class TestAnswerCallbackMessageId:
|
|||
await answer("the answer")
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "answer"
|
||||
assert responses[0].message_type == "answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_default(self, pattern):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator
|
|||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
|
|
@ -130,7 +129,6 @@ class TestAggregatorIntegration:
|
|||
synth = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
|
|
@ -160,7 +158,7 @@ class TestAggregatorIntegration:
|
|||
agg.record_completion("corr-1", "goal", "answer")
|
||||
|
||||
synth = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# correlation_id must be empty so it's not intercepted
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.agent.orchestrator.meta_router import (
|
||||
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
def _make_config(patterns=None, task_types=None):
|
||||
|
|
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
|
|||
def _make_context(prompt_response):
|
||||
"""Build a mock context that returns a mock prompt client."""
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(return_value=prompt_response)
|
||||
client.prompt = AsyncMock(
|
||||
return_value=PromptResult(response_type="text", text=prompt_response)
|
||||
)
|
||||
|
||||
def context(service_name):
|
||||
return client
|
||||
|
|
@ -274,8 +277,8 @@ class TestRoute:
|
|||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return "research" # task type
|
||||
return "plan-then-execute" # pattern
|
||||
return PromptResult(response_type="text", text="research")
|
||||
return PromptResult(response_type="text", text="plan-then-execute")
|
||||
|
||||
client.prompt = mock_prompt
|
||||
context = lambda name: client
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from dataclasses import dataclass, field
|
|||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -68,7 +69,7 @@ def collect_explain_events(respond_mock):
|
|||
events = []
|
||||
for call in respond_mock.call_args_list:
|
||||
resp = call[0][0]
|
||||
if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
|
||||
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
|
||||
events.append({
|
||||
"explain_id": resp.explain_id,
|
||||
"explain_graph": resp.explain_graph,
|
||||
|
|
@ -125,7 +126,6 @@ def make_base_request(**kwargs):
|
|||
state="",
|
||||
group=[],
|
||||
history=[],
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id="test-session-123",
|
||||
|
|
@ -183,7 +183,7 @@ class TestReactPatternProvenance:
|
|||
)
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
# Simulate the on_action callback before returning Final
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
|
|
@ -267,7 +267,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(action)
|
||||
return action
|
||||
|
|
@ -309,7 +309,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
thought="done", name="final",
|
||||
|
|
@ -355,10 +355,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt client for plan creation
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -418,10 +421,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for step execution
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = {
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
}
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="json",
|
||||
object={
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
},
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -475,7 +481,7 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The synthesised answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -542,10 +548,13 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for decomposition
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -590,7 +599,7 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The combined answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -639,7 +648,10 @@ class TestSupervisorPatternProvenance:
|
|||
flow = make_mock_flow()
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=["Goal A", "Goal B", "Goal C"],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_thought_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/thought",
|
||||
|
|
@ -31,7 +31,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_observation_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "observation",
|
||||
"message_type": "observation",
|
||||
"content": "result",
|
||||
"end_of_message": True,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/observation",
|
||||
|
|
@ -42,7 +42,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_answer_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"message_type": "answer",
|
||||
"content": "the answer",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False,
|
||||
|
|
@ -54,7 +54,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_thought_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
}
|
||||
|
|
@ -64,7 +64,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_answer_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"message_type": "answer",
|
||||
"content": "answer",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class MockProcessor:
|
|||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ tool usage patterns.
|
|||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
import asyncio
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
|
|
@ -133,7 +134,7 @@ class TestToolCoordinationLogic:
|
|||
resolved_params[key] = value
|
||||
|
||||
# Execute tool
|
||||
if asyncio.iscoroutinefunction(tool_function):
|
||||
if inspect.iscoroutinefunction(tool_function):
|
||||
result = await tool_function(**resolved_params)
|
||||
else:
|
||||
result = tool_function(**resolved_params)
|
||||
|
|
@ -227,7 +228,7 @@ class TestToolCoordinationLogic:
|
|||
# Simulate async execution with delay
|
||||
await asyncio.sleep(0.001) # Small delay to simulate work
|
||||
|
||||
if asyncio.iscoroutinefunction(tool_function):
|
||||
if inspect.iscoroutinefunction(tool_function):
|
||||
result = await tool_function(**parameters)
|
||||
else:
|
||||
result = tool_function(**parameters)
|
||||
|
|
@ -337,7 +338,7 @@ class TestToolCoordinationLogic:
|
|||
if attempt > 0:
|
||||
await asyncio.sleep(0.001 * (self.backoff_factor ** attempt))
|
||||
|
||||
if asyncio.iscoroutinefunction(tool_function):
|
||||
if inspect.iscoroutinefunction(tool_function):
|
||||
result = await tool_function(**parameters)
|
||||
else:
|
||||
result = tool_function(**parameters)
|
||||
|
|
|
|||
|
|
@ -167,39 +167,28 @@ class TestToolServiceRequest:
|
|||
"""Test cases for tool service request format"""
|
||||
|
||||
def test_request_format(self):
|
||||
"""Test that request is properly formatted with user, config, and arguments"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
"""Test that request is properly formatted with config and arguments"""
|
||||
config_values = {"style": "pun", "collection": "jokes"}
|
||||
arguments = {"topic": "programming"}
|
||||
|
||||
# Act - simulate request building
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values),
|
||||
"arguments": json.dumps(arguments)
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["user"] == "alice"
|
||||
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
|
||||
assert json.loads(request["arguments"]) == {"topic": "programming"}
|
||||
|
||||
def test_request_with_empty_config(self):
|
||||
"""Test request when no config values are provided"""
|
||||
# Arrange
|
||||
user = "bob"
|
||||
config_values = {}
|
||||
arguments = {"query": "test"}
|
||||
|
||||
# Act
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values) if config_values else "{}",
|
||||
"arguments": json.dumps(arguments) if arguments else "{}"
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["config"] == "{}"
|
||||
assert json.loads(request["arguments"]) == {"query": "test"}
|
||||
|
||||
|
|
@ -386,18 +375,13 @@ class TestJokeServiceLogic:
|
|||
assert map_topic_to_category("random topic") == "default"
|
||||
assert map_topic_to_category("") == "default"
|
||||
|
||||
def test_joke_response_personalization(self):
|
||||
"""Test that joke responses include user personalization"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
def test_joke_response_format(self):
|
||||
"""Test that joke response is formatted as expected"""
|
||||
style = "pun"
|
||||
joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
|
||||
|
||||
# Act
|
||||
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
|
||||
response = f"Here's a {style} for you:\n\n{joke}"
|
||||
|
||||
# Assert
|
||||
assert "Hey alice!" in response
|
||||
assert "pun" in response
|
||||
assert joke in response
|
||||
|
||||
|
|
@ -439,20 +423,14 @@ class TestDynamicToolServiceBase:
|
|||
|
||||
def test_request_parsing(self):
|
||||
"""Test parsing of incoming request"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
"user": "alice",
|
||||
"config": '{"style": "pun"}',
|
||||
"arguments": '{"topic": "programming"}'
|
||||
}
|
||||
|
||||
# Act
|
||||
user = request_data.get("user", "trustgraph")
|
||||
config = json.loads(request_data["config"]) if request_data["config"] else {}
|
||||
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
|
||||
|
||||
# Assert
|
||||
assert user == "alice"
|
||||
assert config == {"style": "pun"}
|
||||
assert arguments == {"topic": "programming"}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Tests for tool service lifecycle, invoke contract, streaming responses,
|
||||
multi-tenancy, and error propagation.
|
||||
and error propagation.
|
||||
|
||||
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
|
||||
classes rather than plain dicts.
|
||||
|
|
@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await svc.invoke("user", {}, {})
|
||||
await svc.invoke({}, {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_calls_invoke_with_parsed_args(self):
|
||||
|
|
@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract:
|
|||
|
||||
calls = []
|
||||
|
||||
async def tracking_invoke(user, config, arguments):
|
||||
calls.append({"user": user, "config": config, "arguments": arguments})
|
||||
async def tracking_invoke(config, arguments):
|
||||
calls.append({"config": config, "arguments": arguments})
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking_invoke
|
||||
|
|
@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract:
|
|||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user="alice",
|
||||
config='{"style": "pun"}',
|
||||
arguments='{"topic": "cats"}',
|
||||
)
|
||||
|
|
@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract:
|
|||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["user"] == "alice"
|
||||
assert calls[0]["config"] == {"style": "pun"}
|
||||
assert calls[0]["arguments"] == {"topic": "cats"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_empty_user_defaults_to_trustgraph(self):
|
||||
"""Empty user field should default to 'trustgraph'."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
received_user = None
|
||||
|
||||
async def capture_invoke(user, config, arguments):
|
||||
nonlocal received_user
|
||||
received_user = user
|
||||
return "ok"
|
||||
|
||||
svc.invoke = capture_invoke
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
|
||||
msg.properties.return_value = {"id": "req-2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert received_user == "trustgraph"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_string_response_sent_directly(self):
|
||||
"""String return from invoke → response field is the string."""
|
||||
|
|
@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def string_invoke(user, config, arguments):
|
||||
async def string_invoke(config, arguments):
|
||||
return "hello world"
|
||||
|
||||
svc.invoke = string_invoke
|
||||
|
|
@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r1"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def dict_invoke(user, config, arguments):
|
||||
async def dict_invoke(config, arguments):
|
||||
return {"result": 42}
|
||||
|
||||
svc.invoke = dict_invoke
|
||||
|
|
@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def failing_invoke(user, config, arguments):
|
||||
async def failing_invoke(config, arguments):
|
||||
raise ValueError("bad input")
|
||||
|
||||
svc.invoke = failing_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r3"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -188,13 +157,13 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def rate_limited_invoke(user, config, arguments):
|
||||
async def rate_limited_invoke(config, arguments):
|
||||
raise TooManyRequests("rate limited")
|
||||
|
||||
svc.invoke = rate_limited_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r4"}
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
|
|
@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def ok_invoke(user, config, arguments):
|
||||
async def ok_invoke(config, arguments):
|
||||
return "ok"
|
||||
|
||||
svc.invoke = ok_invoke
|
||||
|
|
@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "unique-42"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -241,7 +210,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return "tool result"
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -260,6 +229,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
|
||||
|
|
@ -280,7 +250,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return {"data": [1, 2, 3]}
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -298,6 +268,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -317,7 +288,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def failing_invoke(name, params):
|
||||
async def failing_invoke(workspace, name, params):
|
||||
raise RuntimeError("tool broke")
|
||||
|
||||
svc.invoke_tool = failing_invoke
|
||||
|
|
@ -330,6 +301,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -350,7 +322,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def rate_limited(name, params):
|
||||
async def rate_limited(workspace, name, params):
|
||||
raise TooManyRequests("slow down")
|
||||
|
||||
svc.invoke_tool = rate_limited
|
||||
|
|
@ -362,6 +334,7 @@ class TestToolServiceOnRequest:
|
|||
flow = MagicMock()
|
||||
flow.producer = {"response": AsyncMock()}
|
||||
flow.name = "test-flow"
|
||||
flow.workspace = "default"
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await svc.on_request(msg, MagicMock(), flow)
|
||||
|
|
@ -376,7 +349,8 @@ class TestToolServiceOnRequest:
|
|||
|
||||
received = {}
|
||||
|
||||
async def capture_invoke(name, params):
|
||||
async def capture_invoke(workspace, name, params):
|
||||
received["workspace"] = workspace
|
||||
received["name"] = name
|
||||
received["params"] = params
|
||||
return "ok"
|
||||
|
|
@ -390,6 +364,7 @@ class TestToolServiceOnRequest:
|
|||
flow = lambda name: mock_pub
|
||||
flow.producer = {"response": mock_pub}
|
||||
flow.name = "f"
|
||||
flow.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(
|
||||
|
|
@ -421,7 +396,6 @@ class TestToolServiceClientCall:
|
|||
))
|
||||
|
||||
result = await client.call(
|
||||
user="alice",
|
||||
config={"style": "pun"},
|
||||
arguments={"topic": "cats"},
|
||||
)
|
||||
|
|
@ -430,7 +404,6 @@ class TestToolServiceClientCall:
|
|||
|
||||
req = client.request.call_args[0][0]
|
||||
assert isinstance(req, ToolServiceRequest)
|
||||
assert req.user == "alice"
|
||||
assert json.loads(req.config) == {"style": "pun"}
|
||||
assert json.loads(req.arguments) == {"topic": "cats"}
|
||||
|
||||
|
|
@ -446,7 +419,7 @@ class TestToolServiceClientCall:
|
|||
))
|
||||
|
||||
with pytest.raises(RuntimeError, match="service down"):
|
||||
await client.call(user="u", config={}, arguments={})
|
||||
await client.call(config={}, arguments={})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_empty_config_sends_empty_json(self):
|
||||
|
|
@ -458,7 +431,7 @@ class TestToolServiceClientCall:
|
|||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config=None, arguments=None)
|
||||
await client.call(config=None, arguments=None)
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.config == "{}"
|
||||
|
|
@ -474,7 +447,7 @@ class TestToolServiceClientCall:
|
|||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config={}, arguments={}, timeout=30)
|
||||
await client.call(config={}, arguments={}, timeout=30)
|
||||
|
||||
_, kwargs = client.request.call_args
|
||||
assert kwargs["timeout"] == 30
|
||||
|
|
@ -509,7 +482,7 @@ class TestToolServiceClientStreaming:
|
|||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
assert result == "chunk1chunk2"
|
||||
|
|
@ -534,7 +507,7 @@ class TestToolServiceClientStreaming:
|
|||
|
||||
with pytest.raises(RuntimeError, match="stream failed"):
|
||||
await client.call_streaming(
|
||||
user="u", config={}, arguments={},
|
||||
config={}, arguments={},
|
||||
callback=AsyncMock(),
|
||||
)
|
||||
|
||||
|
|
@ -564,61 +537,9 @@ class TestToolServiceClientStreaming:
|
|||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
# Empty response is falsy, so callback shouldn't be called for it
|
||||
assert result == "data"
|
||||
assert received == ["data"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-tenancy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMultiTenancy:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_propagated_to_invoke(self):
|
||||
"""User from request should reach the invoke method."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
users_seen = []
|
||||
|
||||
async def tracking(user, config, arguments):
|
||||
users_seen.append(user)
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
for u in ["tenant-a", "tenant-b", "tenant-c"]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user=u, config="{}", arguments="{}",
|
||||
)
|
||||
msg.properties.return_value = {"id": f"req-{u}"}
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_sends_user_in_request(self):
|
||||
"""ToolServiceClient.call should include user in request."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
client.request = AsyncMock(return_value=ToolServiceResponse(
|
||||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="isolated-tenant", config={}, arguments={})
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.user == "isolated-tenant"
|
||||
|
|
|
|||
|
|
@ -1,17 +1,14 @@
|
|||
"""
|
||||
Tests for AsyncProcessor config notify pattern:
|
||||
- register_config_handler with types filtering
|
||||
- on_config_notify version comparison and type matching
|
||||
- fetch_config with short-lived client
|
||||
- fetch_and_apply_config retry logic
|
||||
- on_config_notify version comparison, type/workspace matching
|
||||
- fetch_and_apply_config retry logic over per-workspace fetches
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, Mock
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
# Patch heavy dependencies before importing AsyncProcessor
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
"""Create an AsyncProcessor with mocked dependencies."""
|
||||
|
|
@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
|
|||
assert len(processor.config_handlers) == 2
|
||||
|
||||
|
||||
def _notify_msg(version, changes):
|
||||
"""Build a Mock config-notify message with given version and changes dict."""
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=version, changes=changes)
|
||||
return msg
|
||||
|
||||
|
||||
class TestOnConfigNotify:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -77,9 +81,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=3, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(3, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -91,9 +93,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=5, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(5, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -105,9 +105,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["schema"])
|
||||
|
||||
msg = _notify_msg(2, {"schema": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -121,40 +119,36 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
# Mock fetch_config
|
||||
mock_config = {"prompt": {"key": "value"}}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={"key": "value"},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
handler.assert_called_once_with(
|
||||
"default", {"prompt": {"key": "value"}}, 2
|
||||
)
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_without_types_always_called(self, processor):
|
||||
async def test_handler_without_types_ignored_on_notify(self, processor):
|
||||
"""Handlers registered without types never fire on notifications."""
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler) # No types = all
|
||||
processor.register_config_handler(handler) # No types
|
||||
|
||||
mock_config = {"anything": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["whatever"])
|
||||
msg = _notify_msg(2, {"whatever": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
handler.assert_not_called()
|
||||
# Version still advances past the notify
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_handlers_type_filtering(self, processor):
|
||||
|
|
@ -168,156 +162,149 @@ class TestOnConfigNotify:
|
|||
processor.register_config_handler(schema_handler, types=["schema"])
|
||||
processor.register_config_handler(all_handler)
|
||||
|
||||
mock_config = {"prompt": {}}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
prompt_handler.assert_called_once()
|
||||
prompt_handler.assert_called_once_with(
|
||||
"default", {"prompt": {}}, 2
|
||||
)
|
||||
schema_handler.assert_not_called()
|
||||
all_handler.assert_called_once()
|
||||
all_handler.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_types_invokes_all(self, processor):
|
||||
"""Empty types list (startup signal) should invoke all handlers."""
|
||||
async def test_multi_workspace_notify_invokes_handler_per_ws(
|
||||
self, processor
|
||||
):
|
||||
"""Notify affecting multiple workspaces invokes handler once per workspace."""
|
||||
processor.config_version = 1
|
||||
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
mock_config = {}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=[])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
h1.assert_called_once()
|
||||
h2.assert_called_once()
|
||||
assert handler.call_count == 2
|
||||
called_workspaces = {c.args[0] for c in handler.call_args_list}
|
||||
assert called_workspaces == {"ws1", "ws2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_failure_handled(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler)
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("Connection failed")
|
||||
side_effect=RuntimeError("Connection failed"),
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
# Should not raise
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
|
||||
class TestFetchConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_returns_config_and_version(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.config = {"prompt": {"key": "val"}}
|
||||
mock_resp.version = 42
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
config, version = await processor.fetch_config()
|
||||
|
||||
assert config == {"prompt": {"key": "val"}}
|
||||
assert version == 42
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_raises_on_error_response(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = Mock(message="not found")
|
||||
mock_resp.config = {}
|
||||
mock_resp.version = 0
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="Config error"):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_stops_client_on_exception(self, processor):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = TimeoutError("timeout")
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(TimeoutError):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestFetchAndApplyConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_applies_config_to_all_handlers(self, processor):
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
async def test_applies_config_per_workspace(self, processor):
|
||||
"""Startup fetch invokes handler once per workspace affected."""
|
||||
h = AsyncMock()
|
||||
processor.register_config_handler(h, types=["prompt"])
|
||||
|
||||
mock_client = AsyncMock()
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
return {
|
||||
"ws1": {"k": "v1"},
|
||||
"ws2": {"k": "v2"},
|
||||
}, 10
|
||||
|
||||
mock_config = {"prompt": {}, "schema": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 10)
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
# On startup, all handlers are invoked regardless of type
|
||||
h1.assert_called_once_with(mock_config, 10)
|
||||
h2.assert_called_once_with(mock_config, 10)
|
||||
assert h.call_count == 2
|
||||
call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
|
||||
assert call_map["ws1"] == {"prompt": {"k": "v1"}}
|
||||
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
|
||||
assert processor.config_version == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
call_count = 0
|
||||
mock_config = {"prompt": {}}
|
||||
async def test_handler_without_types_skipped_at_startup(self, processor):
|
||||
"""Handlers registered without types fetch nothing at startup."""
|
||||
typed = AsyncMock()
|
||||
untyped = AsyncMock()
|
||||
processor.register_config_handler(typed, types=["prompt"])
|
||||
processor.register_config_handler(untyped)
|
||||
|
||||
async def mock_fetch():
|
||||
mock_client = AsyncMock()
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
return {"default": {}}, 1
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
typed.assert_called_once()
|
||||
untyped.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
h = AsyncMock()
|
||||
processor.register_config_handler(h, types=["prompt"])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise RuntimeError("not ready")
|
||||
return mock_config, 5
|
||||
return {"default": {"k": "v"}}, 5
|
||||
|
||||
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
|
||||
patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
), patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
assert call_count == 3
|
||||
assert processor.config_version == 5
|
||||
h.assert_called_once_with(
|
||||
"default", {"prompt": {"k": "v"}}, 5
|
||||
)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
result = await client.query(
|
||||
vector=vector,
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
timeout=30
|
||||
)
|
||||
|
|
@ -45,7 +44,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
assert isinstance(call_args, DocumentEmbeddingsRequest)
|
||||
assert call_args.vector == vector
|
||||
assert call_args.limit == 10
|
||||
assert call_args.user == "test_user"
|
||||
assert call_args.collection == "test_collection"
|
||||
|
||||
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||
|
|
@ -104,7 +102,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
client.request.assert_called_once()
|
||||
call_args = client.request.call_args[0][0]
|
||||
assert call_args.limit == 20 # Default limit
|
||||
assert call_args.user == "trustgraph" # Default user
|
||||
assert call_args.collection == "default" # Default collection
|
||||
|
||||
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||
|
|
|
|||
81
tests/unit/test_base/test_flow_base_modules.py
Normal file
81
tests/unit/test_base/test_flow_base_modules.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.base.flow import Flow
|
||||
from trustgraph.base.parameter_spec import Parameter, ParameterSpec
|
||||
from trustgraph.base.spec import Spec
|
||||
|
||||
|
||||
def test_parameter_spec_is_a_spec_and_adds_parameter_value():
|
||||
spec = ParameterSpec("temperature")
|
||||
flow = MagicMock(parameter={})
|
||||
processor = MagicMock()
|
||||
|
||||
spec.add(flow, processor, {"parameters": {"temperature": 0.7}})
|
||||
|
||||
assert isinstance(spec, Spec)
|
||||
assert "temperature" in flow.parameter
|
||||
assert isinstance(flow.parameter["temperature"], Parameter)
|
||||
assert flow.parameter["temperature"].value == 0.7
|
||||
|
||||
|
||||
def test_parameter_spec_defaults_missing_values_to_none():
|
||||
spec = ParameterSpec("model")
|
||||
flow = MagicMock(parameter={})
|
||||
|
||||
spec.add(flow, MagicMock(), {})
|
||||
|
||||
assert flow.parameter["model"].value is None
|
||||
|
||||
|
||||
def test_parameter_start_and_stop_are_awaitable():
|
||||
parameter = Parameter("value")
|
||||
|
||||
assert asyncio.run(parameter.start()) is None
|
||||
assert asyncio.run(parameter.stop()) is None
|
||||
|
||||
|
||||
def test_flow_initialization_calls_registered_specs():
|
||||
spec_one = MagicMock()
|
||||
spec_two = MagicMock()
|
||||
processor = MagicMock(specifications=[spec_one, spec_two])
|
||||
|
||||
flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
|
||||
|
||||
assert flow.id == "processor-1"
|
||||
assert flow.name == "flow-a"
|
||||
assert flow.workspace == "default"
|
||||
assert flow.producer == {}
|
||||
assert flow.consumer == {}
|
||||
assert flow.parameter == {}
|
||||
spec_one.add.assert_called_once_with(flow, processor, {"answer": 42})
|
||||
spec_two.add.assert_called_once_with(flow, processor, {"answer": 42})
|
||||
|
||||
|
||||
def test_flow_start_and_stop_visit_all_consumers():
|
||||
consumer_one = AsyncMock()
|
||||
consumer_two = AsyncMock()
|
||||
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
|
||||
flow.consumer = {"one": consumer_one, "two": consumer_two}
|
||||
|
||||
asyncio.run(flow.start())
|
||||
asyncio.run(flow.stop())
|
||||
|
||||
consumer_one.start.assert_called_once_with()
|
||||
consumer_two.start.assert_called_once_with()
|
||||
consumer_one.stop.assert_called_once_with()
|
||||
consumer_two.stop.assert_called_once_with()
|
||||
|
||||
|
||||
def test_flow_call_returns_values_in_priority_order():
|
||||
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
|
||||
flow.producer["shared"] = "producer-value"
|
||||
flow.consumer["consumer-only"] = "consumer-value"
|
||||
flow.consumer["shared"] = "consumer-value"
|
||||
flow.parameter["parameter-only"] = Parameter("parameter-value")
|
||||
flow.parameter["shared"] = Parameter("parameter-value")
|
||||
|
||||
assert flow("shared") == "producer-value"
|
||||
assert flow("consumer-only") == "consumer-value"
|
||||
assert flow("parameter-only") == "parameter-value"
|
||||
assert flow("missing") is None
|
||||
|
|
@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
|
|||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Act
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
await processor.start_flow("default", flow_name, flow_defn)
|
||||
|
||||
# Assert - Flow should be created with access to processor specifications
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
|
||||
|
||||
# The flow should have access to the processor's specifications
|
||||
# (The exact mechanism depends on Flow implementation)
|
||||
|
|
|
|||
|
|
@ -1,58 +1,50 @@
|
|||
"""
|
||||
Unit tests for trustgraph.base.flow_processor
|
||||
Starting small with a single test to verify basic functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.base.flow_processor import FlowProcessor
|
||||
|
||||
|
||||
# Patches needed to let AsyncProcessor.__init__ run without real
|
||||
# infrastructure while still setting self.id correctly.
|
||||
ASYNC_PROCESSOR_PATCHES = [
|
||||
patch('trustgraph.base.async_processor.get_pubsub', return_value=MagicMock()),
|
||||
patch('trustgraph.base.async_processor.ProcessorMetrics', return_value=MagicMock()),
|
||||
patch('trustgraph.base.async_processor.Consumer', return_value=MagicMock()),
|
||||
]
|
||||
|
||||
|
||||
def with_async_processor_patches(func):
|
||||
"""Apply all AsyncProcessor dependency patches to a test."""
|
||||
for p in reversed(ASYNC_PROCESSOR_PATCHES):
|
||||
func = p(func)
|
||||
return func
|
||||
|
||||
|
||||
class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test FlowProcessor base class functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_flow_processor_initialization_basic(self, mock_register_config, mock_async_init):
|
||||
@with_async_processor_patches
|
||||
async def test_flow_processor_initialization_basic(self, *mocks):
|
||||
"""Test basic FlowProcessor initialization"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify AsyncProcessor.__init__ was called
|
||||
mock_async_init.assert_called_once()
|
||||
|
||||
# Verify register_config_handler was called with the correct handler
|
||||
mock_register_config.assert_called_once_with(
|
||||
processor.on_configure_flows, types=["active-flow"]
|
||||
)
|
||||
|
||||
# Verify FlowProcessor-specific initialization
|
||||
assert hasattr(processor, 'flows')
|
||||
assert processor.id == 'test-flow-processor'
|
||||
assert processor.flows == {}
|
||||
assert hasattr(processor, 'specifications')
|
||||
assert processor.specifications == []
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_register_specification(self, mock_register_config, mock_async_init):
|
||||
@with_async_processor_patches
|
||||
async def test_register_specification(self, *mocks):
|
||||
"""Test registering a specification"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
|
|
@ -62,288 +54,210 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
mock_spec = MagicMock()
|
||||
mock_spec.name = 'test-spec'
|
||||
|
||||
# Act
|
||||
processor.register_specification(mock_spec)
|
||||
|
||||
# Assert
|
||||
assert len(processor.specifications) == 1
|
||||
assert processor.specifications[0] == mock_spec
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_start_flow(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
@with_async_processor_patches
|
||||
async def test_start_flow(self, *mocks):
|
||||
"""Test starting a flow"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
mock_flow_class = mocks[-1]
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'id': 'test-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor' # Set id for Flow creation
|
||||
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
|
||||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Act
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
await processor.start_flow("default", flow_name, flow_defn)
|
||||
|
||||
# Assert
|
||||
assert flow_name in processor.flows
|
||||
# Verify Flow was created with correct parameters
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
|
||||
# Verify the flow's start method was called
|
||||
assert ("default", flow_name) in processor.flows
|
||||
mock_flow_class.assert_called_once_with(
|
||||
'test-processor', flow_name, "default", processor, flow_defn
|
||||
)
|
||||
mock_flow.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_stop_flow(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
@with_async_processor_patches
|
||||
async def test_stop_flow(self, *mocks):
|
||||
"""Test stopping a flow"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
mock_flow_class = mocks[-1]
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'id': 'test-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
|
||||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
await processor.start_flow("default", flow_name, {'config': 'test-config'})
|
||||
|
||||
# Start a flow first
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
|
||||
# Act
|
||||
await processor.stop_flow(flow_name)
|
||||
await processor.stop_flow("default", flow_name)
|
||||
|
||||
# Assert
|
||||
assert flow_name not in processor.flows
|
||||
assert ("default", flow_name) not in processor.flows
|
||||
mock_flow.stop.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_stop_flow_not_exists(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
@with_async_processor_patches
|
||||
async def test_stop_flow_not_exists(self, *mocks):
|
||||
"""Test stopping a flow that doesn't exist"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Act - should not raise an exception
|
||||
await processor.stop_flow('non-existent-flow')
|
||||
|
||||
# Assert - flows dict should still be empty
|
||||
await processor.stop_flow("default", 'non-existent-flow')
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_basic(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
@with_async_processor_patches
|
||||
async def test_on_configure_flows_basic(self, *mocks):
|
||||
"""Test basic flow configuration handling"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
mock_flow_class = mocks[-1]
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'id': 'test-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
# Configuration with flows for this processor
|
||||
flow_config = {
|
||||
'test-flow': {'config': 'test-config'}
|
||||
}
|
||||
|
||||
config_data = {
|
||||
'active-flow': {
|
||||
'test-processor': '{"test-flow": {"config": "test-config"}}'
|
||||
'processor:test-processor': {
|
||||
'test-flow': '{"config": "test-config"}'
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
|
||||
# Assert
|
||||
assert 'test-flow' in processor.flows
|
||||
mock_flow_class.assert_called_once_with('test-processor', 'test-flow', processor, {'config': 'test-config'})
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert ("default", 'test-flow') in processor.flows
|
||||
mock_flow_class.assert_called_once_with(
|
||||
'test-processor', 'test-flow', "default", processor,
|
||||
{'config': 'test-config'}
|
||||
)
|
||||
mock_flow.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_no_config(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
@with_async_processor_patches
|
||||
async def test_on_configure_flows_no_config(self, *mocks):
|
||||
"""Test flow configuration handling when no config exists for this processor"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'id': 'test-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
# Configuration without flows for this processor
|
||||
|
||||
config_data = {
|
||||
'active-flow': {
|
||||
'other-processor': '{"other-flow": {"config": "other-config"}}'
|
||||
'processor:other-processor': {
|
||||
'other-flow': '{"config": "other-config"}'
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
|
||||
# Assert
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert processor.flows == {}
|
||||
mock_flow_class.assert_not_called()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_invalid_config(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
@with_async_processor_patches
|
||||
async def test_on_configure_flows_invalid_config(self, *mocks):
|
||||
"""Test flow configuration handling with invalid config format"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'id': 'test-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
# Configuration without active-flow key
|
||||
|
||||
config_data = {
|
||||
'other-data': 'some-value'
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
|
||||
# Assert
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert processor.flows == {}
|
||||
mock_flow_class.assert_not_called()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_start_and_stop(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
@with_async_processor_patches
|
||||
async def test_on_configure_flows_start_and_stop(self, *mocks):
|
||||
"""Test flow configuration handling with starting and stopping flows"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
mock_flow_class = mocks[-1]
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'id': 'test-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
|
||||
mock_flow1 = AsyncMock()
|
||||
mock_flow2 = AsyncMock()
|
||||
mock_flow_class.side_effect = [mock_flow1, mock_flow2]
|
||||
|
||||
# First configuration - start flow1
|
||||
|
||||
config_data1 = {
|
||||
'active-flow': {
|
||||
'test-processor': '{"flow1": {"config": "config1"}}'
|
||||
'processor:test-processor': {
|
||||
'flow1': '{"config": "config1"}'
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data1, version=1)
|
||||
await processor.on_configure_flows("default", config_data1, version=1)
|
||||
|
||||
# Second configuration - stop flow1, start flow2
|
||||
config_data2 = {
|
||||
'active-flow': {
|
||||
'test-processor': '{"flow2": {"config": "config2"}}'
|
||||
'processor:test-processor': {
|
||||
'flow2': '{"config": "config2"}'
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data2, version=2)
|
||||
|
||||
# Assert
|
||||
# flow1 should be stopped and removed
|
||||
assert 'flow1' not in processor.flows
|
||||
await processor.on_configure_flows("default", config_data2, version=2)
|
||||
|
||||
assert ("default", 'flow1') not in processor.flows
|
||||
mock_flow1.stop.assert_called_once()
|
||||
|
||||
# flow2 should be started and added
|
||||
assert 'flow2' in processor.flows
|
||||
|
||||
assert ("default", 'flow2') in processor.flows
|
||||
mock_flow2.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
@with_async_processor_patches
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.start')
|
||||
async def test_start_calls_parent(self, mock_parent_start, mock_register_config, mock_async_init):
|
||||
async def test_start_calls_parent(self, mock_parent_start, *mocks):
|
||||
"""Test that start() calls parent start method"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
mock_parent_start.return_value = None
|
||||
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Act
|
||||
|
||||
await processor.start()
|
||||
|
||||
# Assert
|
||||
mock_parent_start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_add_args_calls_parent(self, mock_register_config, mock_async_init):
|
||||
async def test_add_args_calls_parent(self):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
|
||||
with patch('trustgraph.base.async_processor.AsyncProcessor.add_args') as mock_parent_add_args:
|
||||
FlowProcessor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__])
|
||||
|
|
|
|||
40
tests/unit/test_base/test_i18n.py
Normal file
40
tests/unit/test_base/test_i18n.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
from trustgraph.i18n import get_language_pack, get_translator, normalize_language
|
||||
|
||||
|
||||
def test_normalize_language_handles_regions_and_accept_language():
|
||||
assert normalize_language(None) == "en"
|
||||
assert normalize_language("") == "en"
|
||||
|
||||
assert normalize_language("es-ES") == "es"
|
||||
assert normalize_language("pt-BR") == "pt"
|
||||
assert normalize_language("zh") == "zh-cn"
|
||||
|
||||
assert normalize_language("es-ES,es;q=0.9,en;q=0.8") == "es"
|
||||
assert normalize_language("unknown") == "en"
|
||||
|
||||
|
||||
def test_language_pack_loads_from_resources():
|
||||
pack = get_language_pack("en")
|
||||
assert isinstance(pack, dict)
|
||||
|
||||
# Key should exist and map to a non-empty string.
|
||||
title = pack.get("cli.verify_system_status.title")
|
||||
assert isinstance(title, str)
|
||||
assert title.strip() != ""
|
||||
|
||||
|
||||
def test_translator_formats_placeholders():
|
||||
tr = get_translator("en")
|
||||
out = tr.t(
|
||||
"cli.verify_system_status.checking_attempt",
|
||||
name="Pulsar",
|
||||
attempt=2,
|
||||
)
|
||||
|
||||
assert "Pulsar" in out
|
||||
assert "2" in out
|
||||
|
||||
|
||||
def test_translator_falls_back_to_key_for_unknown_keys():
|
||||
tr = get_translator("en")
|
||||
assert tr.t("missing.key") == "missing.key"
|
||||
130
tests/unit/test_base/test_logging.py
Normal file
130
tests/unit/test_base/test_logging.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.base.logging import add_logging_args, setup_logging
|
||||
|
||||
|
||||
def test_add_logging_args_uses_environment_defaults(monkeypatch):
|
||||
monkeypatch.setenv("LOKI_URL", "http://example.test/loki")
|
||||
monkeypatch.setenv("LOKI_USERNAME", "user")
|
||||
monkeypatch.setenv("LOKI_PASSWORD", "pass")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
add_logging_args(parser)
|
||||
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert args.log_level == "INFO"
|
||||
assert args.loki_enabled is True
|
||||
assert args.loki_url == "http://example.test/loki"
|
||||
assert args.loki_username == "user"
|
||||
assert args.loki_password == "pass"
|
||||
|
||||
|
||||
def test_add_logging_args_supports_disabling_loki():
|
||||
parser = argparse.ArgumentParser()
|
||||
add_logging_args(parser)
|
||||
|
||||
args = parser.parse_args(["--no-loki-enabled"])
|
||||
|
||||
assert args.loki_enabled is False
|
||||
|
||||
|
||||
def test_setup_logging_without_loki_configures_console(monkeypatch):
|
||||
basic_config = MagicMock()
|
||||
logger = MagicMock()
|
||||
|
||||
monkeypatch.setattr(logging, "basicConfig", basic_config)
|
||||
monkeypatch.setattr(logging, "getLogger", lambda name=None: logger)
|
||||
|
||||
setup_logging({"log_level": "debug", "loki_enabled": False, "id": "processor-1"})
|
||||
|
||||
kwargs = basic_config.call_args.kwargs
|
||||
assert kwargs["level"] == logging.DEBUG
|
||||
assert kwargs["force"] is True
|
||||
assert "%(processor_id)s" in kwargs["format"]
|
||||
assert len(kwargs["handlers"]) == 1
|
||||
logger.info.assert_called_once_with("Logging configured with level: debug")
|
||||
|
||||
|
||||
def test_setup_logging_with_loki_enables_queue_listener(monkeypatch):
|
||||
basic_config = MagicMock()
|
||||
root_logger = MagicMock()
|
||||
module_logger = MagicMock()
|
||||
urllib3_logger = MagicMock()
|
||||
connectionpool_logger = MagicMock()
|
||||
queue_handler = MagicMock()
|
||||
queue_listener = MagicMock()
|
||||
loki_handler = MagicMock()
|
||||
|
||||
noisy_logger = MagicMock()
|
||||
logger_map = {
|
||||
None: root_logger,
|
||||
"trustgraph.base.logging": module_logger,
|
||||
"urllib3": urllib3_logger,
|
||||
"urllib3.connectionpool": connectionpool_logger,
|
||||
"pika": noisy_logger,
|
||||
"cassandra": noisy_logger,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(logging, "basicConfig", basic_config)
|
||||
monkeypatch.setattr(logging, "getLogger", lambda name=None: logger_map[name])
|
||||
monkeypatch.setattr(
|
||||
logging.handlers,
|
||||
"QueueHandler",
|
||||
MagicMock(return_value=queue_handler),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
logging.handlers,
|
||||
"QueueListener",
|
||||
MagicMock(return_value=queue_listener),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"logging_loki",
|
||||
SimpleNamespace(LokiHandler=MagicMock(return_value=loki_handler)),
|
||||
)
|
||||
|
||||
setup_logging(
|
||||
{
|
||||
"log_level": "INFO",
|
||||
"loki_enabled": True,
|
||||
"loki_url": "http://loki.test/push",
|
||||
"loki_username": "user",
|
||||
"loki_password": "pass",
|
||||
"id": "processor-1",
|
||||
}
|
||||
)
|
||||
|
||||
assert root_logger.loki_queue_listener is queue_listener
|
||||
queue_listener.start.assert_called_once_with()
|
||||
urllib3_logger.setLevel.assert_called_once_with(logging.WARNING)
|
||||
connectionpool_logger.setLevel.assert_called_once_with(logging.WARNING)
|
||||
module_logger.info.assert_any_call("Logging configured with level: INFO")
|
||||
module_logger.info.assert_any_call("Loki logging enabled: http://loki.test/push")
|
||||
|
||||
|
||||
def test_setup_logging_falls_back_when_loki_module_missing(monkeypatch, capsys):
|
||||
basic_config = MagicMock()
|
||||
logger = MagicMock()
|
||||
|
||||
monkeypatch.setattr(logging, "basicConfig", basic_config)
|
||||
monkeypatch.setattr(logging, "getLogger", lambda name=None: logger)
|
||||
monkeypatch.delitem(sys.modules, "logging_loki", raising=False)
|
||||
real_import = __import__
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
if name == "logging_loki":
|
||||
raise ImportError("missing")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("builtins.__import__", fake_import)
|
||||
|
||||
setup_logging({"log_level": "INFO", "loki_enabled": True, "id": "processor-1"})
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "python-logging-loki not installed" in output
|
||||
logger.warning.assert_called_once_with("Loki logging requested but not available")
|
||||
143
tests/unit/test_base/test_metrics.py
Normal file
143
tests/unit/test_base/test_metrics.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.base import metrics
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_metric_singletons():
|
||||
"""Temporarily remove metric singletons so each test can inject mocks.
|
||||
|
||||
Saves any existing class-level metrics and restores them after the test
|
||||
so that later tests in the same process still find the hasattr() guard
|
||||
intact — deleting without restoring causes every subsequent Processor()
|
||||
construction to re-register the same Prometheus metric name, which raises
|
||||
ValueError: Duplicated timeseries.
|
||||
"""
|
||||
classes_and_attrs = {
|
||||
metrics.ConsumerMetrics: [
|
||||
"state_metric",
|
||||
"request_metric",
|
||||
"processing_metric",
|
||||
"rate_limit_metric",
|
||||
],
|
||||
metrics.ProducerMetrics: ["producer_metric"],
|
||||
metrics.ProcessorMetrics: ["processor_metric"],
|
||||
metrics.SubscriberMetrics: [
|
||||
"state_metric",
|
||||
"received_metric",
|
||||
"dropped_metric",
|
||||
],
|
||||
}
|
||||
|
||||
saved = {}
|
||||
for cls, attrs in classes_and_attrs.items():
|
||||
for attr in attrs:
|
||||
if hasattr(cls, attr):
|
||||
saved[(cls, attr)] = getattr(cls, attr)
|
||||
delattr(cls, attr)
|
||||
|
||||
yield
|
||||
|
||||
# Remove anything the test may have set, then restore originals
|
||||
for cls, attrs in classes_and_attrs.items():
|
||||
for attr in attrs:
|
||||
if hasattr(cls, attr):
|
||||
delattr(cls, attr)
|
||||
|
||||
for (cls, attr), value in saved.items():
|
||||
setattr(cls, attr, value)
|
||||
|
||||
|
||||
def test_consumer_metrics_reuses_singletons_and_records_events(monkeypatch):
|
||||
enum_factory = MagicMock()
|
||||
histogram_factory = MagicMock()
|
||||
counter_factory = MagicMock()
|
||||
|
||||
state_labels = MagicMock()
|
||||
request_labels = MagicMock()
|
||||
processing_labels = MagicMock()
|
||||
rate_limit_labels = MagicMock()
|
||||
timer = MagicMock()
|
||||
|
||||
enum_factory.return_value.labels.return_value = state_labels
|
||||
histogram_factory.return_value.labels.return_value = request_labels
|
||||
request_labels.time.return_value = timer
|
||||
counter_factory.side_effect = [
|
||||
MagicMock(labels=MagicMock(return_value=processing_labels)),
|
||||
MagicMock(labels=MagicMock(return_value=rate_limit_labels)),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(metrics, "Enum", enum_factory)
|
||||
monkeypatch.setattr(metrics, "Histogram", histogram_factory)
|
||||
monkeypatch.setattr(metrics, "Counter", counter_factory)
|
||||
|
||||
first = metrics.ConsumerMetrics("proc", "flow", "name")
|
||||
second = metrics.ConsumerMetrics("proc-2", "flow-2", "name-2")
|
||||
|
||||
assert enum_factory.call_count == 1
|
||||
assert histogram_factory.call_count == 1
|
||||
assert counter_factory.call_count == 2
|
||||
|
||||
first.process("ok")
|
||||
first.rate_limit()
|
||||
first.state("running")
|
||||
assert first.record_time() is timer
|
||||
|
||||
processing_labels.inc.assert_called_once_with()
|
||||
rate_limit_labels.inc.assert_called_once_with()
|
||||
state_labels.state.assert_called_once_with("running")
|
||||
|
||||
|
||||
def test_producer_metrics_increments_counter_once(monkeypatch):
|
||||
counter_factory = MagicMock()
|
||||
labels = MagicMock()
|
||||
counter_factory.return_value.labels.return_value = labels
|
||||
monkeypatch.setattr(metrics, "Counter", counter_factory)
|
||||
|
||||
producer_metrics = metrics.ProducerMetrics("proc", "flow", "output")
|
||||
producer_metrics.inc()
|
||||
|
||||
counter_factory.assert_called_once()
|
||||
labels.inc.assert_called_once_with()
|
||||
|
||||
|
||||
def test_processor_metrics_reports_info(monkeypatch):
|
||||
info_factory = MagicMock()
|
||||
labels = MagicMock()
|
||||
info_factory.return_value.labels.return_value = labels
|
||||
monkeypatch.setattr(metrics, "Info", info_factory)
|
||||
|
||||
processor_metrics = metrics.ProcessorMetrics("proc")
|
||||
processor_metrics.info({"kind": "test"})
|
||||
|
||||
info_factory.assert_called_once()
|
||||
labels.info.assert_called_once_with({"kind": "test"})
|
||||
|
||||
|
||||
def test_subscriber_metrics_tracks_received_state_and_dropped(monkeypatch):
|
||||
enum_factory = MagicMock()
|
||||
counter_factory = MagicMock()
|
||||
|
||||
state_labels = MagicMock()
|
||||
received_labels = MagicMock()
|
||||
dropped_labels = MagicMock()
|
||||
|
||||
enum_factory.return_value.labels.return_value = state_labels
|
||||
counter_factory.side_effect = [
|
||||
MagicMock(labels=MagicMock(return_value=received_labels)),
|
||||
MagicMock(labels=MagicMock(return_value=dropped_labels)),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(metrics, "Enum", enum_factory)
|
||||
monkeypatch.setattr(metrics, "Counter", counter_factory)
|
||||
|
||||
subscriber_metrics = metrics.SubscriberMetrics("proc", "flow", "input")
|
||||
subscriber_metrics.received()
|
||||
subscriber_metrics.state("running")
|
||||
subscriber_metrics.dropped("ignored")
|
||||
|
||||
received_labels.inc.assert_called_once_with()
|
||||
dropped_labels.inc.assert_called_once_with()
|
||||
state_labels.state.assert_called_once_with("running")
|
||||
|
|
@ -236,6 +236,10 @@ async def test_subscriber_graceful_shutdown():
|
|||
with patch.object(subscriber, 'run') as mock_run:
|
||||
# Mock run that simulates graceful shutdown
|
||||
async def mock_run_graceful():
|
||||
# Honor the readiness contract: real run() signals _ready
|
||||
# after binding the consumer, so start() can unblock. Mocks
|
||||
# of run() must do the same or start() hangs forever.
|
||||
subscriber._ready.set_result(None)
|
||||
# Process messages while running, then drain
|
||||
while subscriber.running or subscriber.draining:
|
||||
if subscriber.draining:
|
||||
|
|
@ -337,6 +341,8 @@ async def test_subscriber_pending_acks_cleanup():
|
|||
with patch.object(subscriber, 'run') as mock_run:
|
||||
# Mock run that simulates cleanup of pending acks
|
||||
async def mock_run_cleanup():
|
||||
# Honor the readiness contract — see test_subscriber_graceful_shutdown.
|
||||
subscriber._ready.set_result(None)
|
||||
while subscriber.running or subscriber.draining:
|
||||
await asyncio.sleep(0.05)
|
||||
if subscriber.draining:
|
||||
|
|
@ -406,4 +412,4 @@ async def test_subscriber_multiple_subscribers():
|
|||
msg1 = await queue1.get()
|
||||
msg_all = await queue_all.get()
|
||||
assert msg1 == {"data": "broadcast"}
|
||||
assert msg_all == {"data": "broadcast"}
|
||||
assert msg_all == {"data": "broadcast"}
|
||||
|
|
|
|||
189
tests/unit/test_base/test_subscriber_readiness.py
Normal file
189
tests/unit/test_base/test_subscriber_readiness.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
"""
|
||||
Regression tests for Subscriber.start() readiness barrier.
|
||||
|
||||
Background: prior to the eager-connect fix, Subscriber.start() created
|
||||
the run() task and returned immediately. The underlying backend consumer
|
||||
was lazily connected on its first receive() call, which left a setup
|
||||
race for request/response clients using ephemeral per-subscriber response
|
||||
queues (RabbitMQ auto-delete exclusive queues): the request would be
|
||||
published before the response queue was bound, and the broker would
|
||||
silently drop the reply. fetch_config(), document-embeddings, and
|
||||
api-gateway all hit this with "Failed to fetch config on notify" /
|
||||
"Request timeout exception" symptoms.
|
||||
|
||||
These tests pin the readiness contract:
|
||||
|
||||
await subscriber.start()
|
||||
# at this point, consumer.ensure_connected() MUST have run
|
||||
|
||||
so that any future change which removes the eager bind, or moves it
|
||||
back to lazy initialisation, fails CI loudly.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.base.subscriber import Subscriber
|
||||
|
||||
|
||||
def _make_backend(ensure_connected_side_effect=None,
|
||||
receive_side_effect=None):
|
||||
"""Build a fake backend whose consumer records ensure_connected /
|
||||
receive calls. ensure_connected_side_effect lets a test inject a
|
||||
delay or exception."""
|
||||
backend = MagicMock()
|
||||
consumer = MagicMock()
|
||||
|
||||
consumer.ensure_connected = MagicMock(
|
||||
side_effect=ensure_connected_side_effect,
|
||||
)
|
||||
|
||||
# By default receive raises a timeout-style exception that the
|
||||
# subscriber loop is supposed to swallow as a "no message yet" — this
|
||||
# keeps the subscriber idling cleanly while the test inspects state.
|
||||
if receive_side_effect is None:
|
||||
receive_side_effect = TimeoutError("No message received within timeout")
|
||||
consumer.receive = MagicMock(side_effect=receive_side_effect)
|
||||
|
||||
consumer.acknowledge = MagicMock()
|
||||
consumer.negative_acknowledge = MagicMock()
|
||||
consumer.pause_message_listener = MagicMock()
|
||||
consumer.unsubscribe = MagicMock()
|
||||
consumer.close = MagicMock()
|
||||
|
||||
backend.create_consumer.return_value = consumer
|
||||
return backend, consumer
|
||||
|
||||
|
||||
def _make_subscriber(backend):
|
||||
return Subscriber(
|
||||
backend=backend,
|
||||
topic="response:tg:config",
|
||||
subscription="test-sub",
|
||||
consumer_name="test-consumer",
|
||||
schema=dict,
|
||||
max_size=10,
|
||||
drain_timeout=1.0,
|
||||
backpressure_strategy="block",
|
||||
)
|
||||
|
||||
|
||||
class TestSubscriberReadiness:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_ensure_connected_before_returning(self):
|
||||
"""The barrier: ensure_connected must have been invoked at least
|
||||
once by the time start() returns."""
|
||||
backend, consumer = _make_backend()
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
await subscriber.start()
|
||||
|
||||
try:
|
||||
consumer.ensure_connected.assert_called_once()
|
||||
finally:
|
||||
await subscriber.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_blocks_until_ensure_connected_completes(self):
|
||||
"""If ensure_connected is slow, start() must wait for it. This is
|
||||
the actual race-condition guard — it would have failed against
|
||||
the buggy version where start() returned before run() had even
|
||||
scheduled the consumer creation."""
|
||||
connect_started = asyncio.Event()
|
||||
release_connect = asyncio.Event()
|
||||
|
||||
# ensure_connected runs in the executor thread, so we need a
|
||||
# threading-safe gate. Use a simple busy-wait on a flag set by
|
||||
# the asyncio side via call_soon_threadsafe — but the simpler
|
||||
# path is to give it a sleep and observe ordering.
|
||||
import threading
|
||||
gate = threading.Event()
|
||||
|
||||
def slow_connect():
|
||||
connect_started.set() # safe: only mutates the Event flag
|
||||
gate.wait(timeout=2.0)
|
||||
|
||||
backend, consumer = _make_backend(
|
||||
ensure_connected_side_effect=slow_connect,
|
||||
)
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
start_task = asyncio.create_task(subscriber.start())
|
||||
|
||||
# Wait until ensure_connected has begun executing.
|
||||
await asyncio.wait_for(connect_started.wait(), timeout=2.0)
|
||||
|
||||
# ensure_connected is in flight — start() must NOT have returned.
|
||||
assert not start_task.done(), (
|
||||
"start() returned before ensure_connected() completed — "
|
||||
"the readiness barrier is broken and the request/response "
|
||||
"race condition is back."
|
||||
)
|
||||
|
||||
# Release the gate; start() should now complete promptly.
|
||||
gate.set()
|
||||
await asyncio.wait_for(start_task, timeout=2.0)
|
||||
|
||||
consumer.ensure_connected.assert_called_once()
|
||||
|
||||
await subscriber.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_propagates_consumer_creation_failure(self):
|
||||
"""If create_consumer() raises, start() must surface the error
|
||||
rather than hang on the readiness future. The old code path
|
||||
retried indefinitely inside run() and never let start() unblock."""
|
||||
backend = MagicMock()
|
||||
backend.create_consumer.side_effect = RuntimeError("broker down")
|
||||
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
with pytest.raises(RuntimeError, match="broker down"):
|
||||
await asyncio.wait_for(subscriber.start(), timeout=2.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_propagates_ensure_connected_failure(self):
|
||||
"""Same contract for an ensure_connected() that raises (e.g. the
|
||||
broker is up but the queue declare/bind fails)."""
|
||||
backend, consumer = _make_backend(
|
||||
ensure_connected_side_effect=RuntimeError("queue declare failed"),
|
||||
)
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
with pytest.raises(RuntimeError, match="queue declare failed"):
|
||||
await asyncio.wait_for(subscriber.start(), timeout=2.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_connected_runs_before_subscriber_running_log(self):
|
||||
"""Subtle ordering: ensure_connected MUST happen before the
|
||||
receive loop, so that any reply is captured. We assert this by
|
||||
checking ensure_connected was called before any receive call."""
|
||||
call_order = []
|
||||
|
||||
def record_ensure():
|
||||
call_order.append("ensure_connected")
|
||||
|
||||
def record_receive(*args, **kwargs):
|
||||
call_order.append("receive")
|
||||
raise TimeoutError("No message received within timeout")
|
||||
|
||||
backend, consumer = _make_backend(
|
||||
ensure_connected_side_effect=record_ensure,
|
||||
receive_side_effect=record_receive,
|
||||
)
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
await subscriber.start()
|
||||
|
||||
# Give the receive loop a tick to run at least once.
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
await subscriber.stop()
|
||||
|
||||
# ensure_connected must come first; receive may not have happened
|
||||
# yet on a fast machine, but if it did, it must come after.
|
||||
assert call_order, "neither ensure_connected nor receive was called"
|
||||
assert call_order[0] == "ensure_connected"
|
||||
|
|
@ -28,7 +28,6 @@ def sample_text_document():
|
|||
"""Sample document with moderate length text."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-1",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = "The quick brown fox jumps over the lazy dog. " * 20
|
||||
|
|
@ -43,7 +42,6 @@ def long_text_document():
|
|||
"""Long document for testing multiple chunks."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-long",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
# Create a long text that will definitely be chunked
|
||||
|
|
@ -59,7 +57,6 @@ def unicode_text_document():
|
|||
"""Document with various unicode characters."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-unicode",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = """
|
||||
|
|
@ -84,7 +81,6 @@ def empty_text_document():
|
|||
"""Empty document for edge case testing."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-empty",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
return TextDocument(
|
||||
|
|
|
|||
|
|
@ -70,11 +70,12 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
# Flow exposes parameter lookup via __call__: flow("chunk-size")
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 2000, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -105,10 +106,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 200 # Override chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -139,10 +140,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 1500, # Override chunk size
|
||||
"chunk-overlap": 150 # Override chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -184,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content"
|
||||
|
|
@ -195,15 +195,15 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
# Flow.__call__ resolves parameters and producers/consumers from the
|
||||
# same dict — merge both kinds here.
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 1500,
|
||||
"chunk-overlap": 150,
|
||||
}.get(param)
|
||||
mock_flow.side_effect = lambda name: {
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(name)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
|
@ -241,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.return_value = None # No overrides
|
||||
mock_flow.side_effect = lambda key: None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
|
|||
|
|
@ -70,11 +70,12 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
# Flow exposes parameter lookup via __call__: flow("chunk-size")
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 400, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -105,10 +106,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 25 # Override chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -139,10 +140,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 350, # Override chunk size
|
||||
"chunk-overlap": 30 # Override chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -184,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-456",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content for token chunking"
|
||||
|
|
@ -195,15 +195,15 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
# Flow.__call__ resolves parameters and producers/consumers from the
|
||||
# same dict — merge both kinds here.
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 400,
|
||||
"chunk-overlap": 40,
|
||||
}.get(param)
|
||||
mock_flow.side_effect = lambda name: {
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(name)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
|
@ -245,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.return_value = None # No overrides
|
||||
mock_flow.side_effect = lambda key: None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
|
|||
|
|
@ -109,7 +109,8 @@ class TestListConfigItems:
|
|||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
format_type='json',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_list_main_uses_defaults(self):
|
||||
|
|
@ -128,7 +129,8 @@ class TestListConfigItems:
|
|||
url='http://localhost:8088/',
|
||||
config_type='prompt',
|
||||
format_type='text',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -196,7 +198,8 @@ class TestGetConfigItem:
|
|||
config_type='prompt',
|
||||
key='template-1',
|
||||
format_type='json',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -253,7 +256,8 @@ class TestPutConfigItem:
|
|||
config_type='prompt',
|
||||
key='new-template',
|
||||
value='Custom prompt: {input}',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_put_main_with_stdin_arg(self):
|
||||
|
|
@ -278,7 +282,8 @@ class TestPutConfigItem:
|
|||
config_type='prompt',
|
||||
key='stdin-template',
|
||||
value=stdin_content,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_put_main_mutually_exclusive_args(self):
|
||||
|
|
@ -334,7 +339,8 @@ class TestDeleteConfigItem:
|
|||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
key='old-template',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ def knowledge_loader():
|
|||
return KnowledgeLoader(
|
||||
files=["test.ttl"],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc-123",
|
||||
url="http://test.example.com/",
|
||||
|
|
@ -64,7 +64,7 @@ class TestKnowledgeLoader:
|
|||
loader = KnowledgeLoader(
|
||||
files=["file1.ttl", "file2.ttl"],
|
||||
flow="my-flow",
|
||||
user="user1",
|
||||
workspace="user1",
|
||||
collection="col1",
|
||||
document_id="doc1",
|
||||
url="http://example.com/",
|
||||
|
|
@ -73,7 +73,7 @@ class TestKnowledgeLoader:
|
|||
|
||||
assert loader.files == ["file1.ttl", "file2.ttl"]
|
||||
assert loader.flow == "my-flow"
|
||||
assert loader.user == "user1"
|
||||
assert loader.workspace == "user1"
|
||||
assert loader.collection == "col1"
|
||||
assert loader.document_id == "doc1"
|
||||
assert loader.url == "http://example.com/"
|
||||
|
|
@ -126,7 +126,7 @@ ex:mary ex:knows ex:bob .
|
|||
loader = KnowledgeLoader(
|
||||
files=[f.name],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc",
|
||||
url="http://test.example.com/"
|
||||
|
|
@ -151,7 +151,7 @@ ex:mary ex:knows ex:bob .
|
|||
loader = KnowledgeLoader(
|
||||
files=[temp_turtle_file],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc",
|
||||
url="http://test.example.com/",
|
||||
|
|
@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
|
|||
# Verify Api was created with correct parameters
|
||||
mock_api_class.assert_called_once_with(
|
||||
url="http://test.example.com/",
|
||||
token="test-token"
|
||||
token="test-token",
|
||||
workspace="test-user"
|
||||
)
|
||||
|
||||
# Verify bulk client was obtained
|
||||
|
|
@ -174,7 +175,6 @@ ex:mary ex:knows ex:bob .
|
|||
call_args = mock_bulk.import_triples.call_args
|
||||
assert call_args[1]['flow'] == "test-flow"
|
||||
assert call_args[1]['metadata']['id'] == "test-doc"
|
||||
assert call_args[1]['metadata']['user'] == "test-user"
|
||||
assert call_args[1]['metadata']['collection'] == "test-collection"
|
||||
|
||||
# Verify import_entity_contexts was called
|
||||
|
|
@ -198,7 +198,7 @@ class TestCLIArgumentParsing:
|
|||
'tg-load-knowledge',
|
||||
'-i', 'doc-123',
|
||||
'-f', 'my-flow',
|
||||
'-U', 'my-user',
|
||||
'-w', 'my-user',
|
||||
'-C', 'my-collection',
|
||||
'-u', 'http://custom.example.com/',
|
||||
'-t', 'my-token',
|
||||
|
|
@ -216,7 +216,7 @@ class TestCLIArgumentParsing:
|
|||
token='my-token',
|
||||
flow='my-flow',
|
||||
files=['file1.ttl', 'file2.ttl'],
|
||||
user='my-user',
|
||||
workspace='my-user',
|
||||
collection='my-collection'
|
||||
)
|
||||
|
||||
|
|
@ -242,7 +242,7 @@ class TestCLIArgumentParsing:
|
|||
# Verify defaults were used
|
||||
call_args = mock_loader_class.call_args[1]
|
||||
assert call_args['flow'] == 'default'
|
||||
assert call_args['user'] == 'trustgraph'
|
||||
assert call_args['workspace'] == 'default'
|
||||
assert call_args['collection'] == 'default'
|
||||
assert call_args['url'] == 'http://localhost:8088/'
|
||||
assert call_args['token'] is None
|
||||
|
|
@ -287,7 +287,7 @@ class TestErrorHandling:
|
|||
loader = KnowledgeLoader(
|
||||
files=[temp_turtle_file],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc",
|
||||
url="http://test.example.com/"
|
||||
|
|
|
|||
|
|
@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
|
|||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_set_main_structured_query_no_arguments_needed(self):
|
||||
|
|
@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
|
|||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_valid_types_includes_row_embeddings_query(self):
|
||||
|
|
@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
|
|||
|
||||
show_main()
|
||||
|
||||
mock_show.assert_called_once_with(url='http://custom.com', token=None)
|
||||
mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
|
||||
|
||||
|
||||
class TestShowToolsRowEmbeddingsQuery:
|
||||
|
|
|
|||
|
|
@ -73,7 +73,6 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
# Act
|
||||
result = client.request(
|
||||
vector=vector,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
limit=10,
|
||||
timeout=300
|
||||
|
|
@ -82,7 +81,6 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
# Assert
|
||||
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||
client.call.assert_called_once_with(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
vector=vector,
|
||||
limit=10,
|
||||
|
|
@ -108,7 +106,6 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
# Assert
|
||||
assert result == ["test_chunk"]
|
||||
client.call.assert_called_once_with(
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
vector=vector,
|
||||
limit=10,
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ def _make_query(
|
|||
|
||||
query = Query(
|
||||
rag=rag,
|
||||
user="test-user",
|
||||
collection="test-collection",
|
||||
verbose=False,
|
||||
entity_limit=entity_limit,
|
||||
|
|
@ -208,7 +207,6 @@ class TestBatchTripleQueries:
|
|||
assert calls[0].kwargs["p"] is None
|
||||
assert calls[0].kwargs["o"] is None
|
||||
assert calls[0].kwargs["limit"] == 15
|
||||
assert calls[0].kwargs["user"] == "test-user"
|
||||
assert calls[0].kwargs["collection"] == "test-collection"
|
||||
assert calls[0].kwargs["batch_size"] == 20
|
||||
|
||||
|
|
|
|||
|
|
@ -28,10 +28,12 @@ def mock_flow_config():
|
|||
"""Mock flow configuration."""
|
||||
mock_config = Mock()
|
||||
mock_config.flows = {
|
||||
"test-flow": {
|
||||
"interfaces": {
|
||||
"triples-store": "test-triples-queue",
|
||||
"graph-embeddings-store": "test-ge-queue"
|
||||
"test-user": {
|
||||
"test-flow": {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test-triples-queue"},
|
||||
"graph-embeddings-store": {"flow": "test-ge-queue"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -43,7 +45,7 @@ def mock_flow_config():
|
|||
def mock_request():
|
||||
"""Mock knowledge load request."""
|
||||
request = Mock()
|
||||
request.user = "test-user"
|
||||
request.workspace = "test-user"
|
||||
request.id = "test-doc-id"
|
||||
request.collection = "test-collection"
|
||||
request.flow = "test-flow"
|
||||
|
|
@ -71,7 +73,6 @@ def sample_triples():
|
|||
return Triples(
|
||||
metadata=Metadata(
|
||||
id="test-doc-id",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
),
|
||||
triples=[
|
||||
|
|
@ -90,7 +91,6 @@ def sample_graph_embeddings():
|
|||
return GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id="test-doc-id",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
),
|
||||
entities=[
|
||||
|
|
@ -146,7 +146,6 @@ class TestKnowledgeManagerLoadCore:
|
|||
mock_triples_pub.send.assert_called_once()
|
||||
sent_triples = mock_triples_pub.send.call_args[0][1]
|
||||
assert sent_triples.metadata.collection == "test-collection"
|
||||
assert sent_triples.metadata.user == "test-user"
|
||||
assert sent_triples.metadata.id == "test-doc-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -185,7 +184,6 @@ class TestKnowledgeManagerLoadCore:
|
|||
mock_ge_pub.send.assert_called_once()
|
||||
sent_ge = mock_ge_pub.send.call_args[0][1]
|
||||
assert sent_ge.metadata.collection == "test-collection"
|
||||
assert sent_ge.metadata.user == "test-user"
|
||||
assert sent_ge.metadata.id == "test-doc-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -193,7 +191,7 @@ class TestKnowledgeManagerLoadCore:
|
|||
"""Test that load_kg_core falls back to 'default' when request.collection is None."""
|
||||
# Create request with None collection
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
mock_request.collection = None # Should fall back to "default"
|
||||
mock_request.flow = "test-flow"
|
||||
|
|
@ -269,7 +267,7 @@ class TestKnowledgeManagerLoadCore:
|
|||
"""Test that load_kg_core validates flow configuration before processing."""
|
||||
# Request with invalid flow
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
mock_request.collection = "test-collection"
|
||||
mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows
|
||||
|
|
@ -297,7 +295,7 @@ class TestKnowledgeManagerLoadCore:
|
|||
|
||||
# Test missing ID
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = None # Missing
|
||||
mock_request.collection = "test-collection"
|
||||
mock_request.flow = "test-flow"
|
||||
|
|
@ -323,7 +321,7 @@ class TestKnowledgeManagerOtherMethods:
|
|||
async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples):
|
||||
"""Test that get_kg_core preserves collection field from stored data."""
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
|
||||
mock_respond = AsyncMock()
|
||||
|
|
@ -354,7 +352,7 @@ class TestKnowledgeManagerOtherMethods:
|
|||
async def test_list_kg_cores(self, knowledge_manager):
|
||||
"""Test listing knowledge cores."""
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
|
||||
mock_respond = AsyncMock()
|
||||
|
||||
|
|
@ -376,7 +374,7 @@ class TestKnowledgeManagerOtherMethods:
|
|||
async def test_delete_kg_core(self, knowledge_manager):
|
||||
"""Test deleting knowledge cores."""
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
|
||||
mock_respond = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock message with inline data
|
||||
content = b"# Document Title\nBody text content."
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
mock_metadata = Metadata(id="test-doc",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
|
|
@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock message
|
||||
content = b"fake pdf"
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
mock_metadata = Metadata(id="test-doc",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
|
|
@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
]
|
||||
|
||||
content = b"fake pdf"
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
mock_metadata = Metadata(id="test-doc",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_basic(self):
|
||||
"""Test basic collection name creation"""
|
||||
result = make_safe_collection_name(
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -21,7 +21,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_special_characters(self):
|
||||
"""Test collection name creation with special characters that need sanitization"""
|
||||
result = make_safe_collection_name(
|
||||
user="user@domain.com",
|
||||
workspace="user@domain.com",
|
||||
collection="test-collection.v2",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -30,7 +30,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_unicode(self):
|
||||
"""Test collection name creation with Unicode characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="测试用户",
|
||||
workspace="测试用户",
|
||||
collection="colección_española",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -39,7 +39,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_spaces(self):
|
||||
"""Test collection name creation with spaces"""
|
||||
result = make_safe_collection_name(
|
||||
user="test user",
|
||||
workspace="test user",
|
||||
collection="my test collection",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -48,7 +48,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self):
|
||||
"""Test collection name creation with multiple consecutive special characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="user@@@domain!!!",
|
||||
workspace="user@@@domain!!!",
|
||||
collection="test---collection...v2",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -57,7 +57,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_leading_trailing_underscores(self):
|
||||
"""Test collection name creation with leading/trailing special characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="__test_user__",
|
||||
workspace="__test_user__",
|
||||
collection="@@test_collection##",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -66,7 +66,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_empty_user(self):
|
||||
"""Test collection name creation with empty user (should fallback to 'default')"""
|
||||
result = make_safe_collection_name(
|
||||
user="",
|
||||
workspace="",
|
||||
collection="test_collection",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -75,7 +75,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_empty_collection(self):
|
||||
"""Test collection name creation with empty collection (should fallback to 'default')"""
|
||||
result = make_safe_collection_name(
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -84,7 +84,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_both_empty(self):
|
||||
"""Test collection name creation with both user and collection empty"""
|
||||
result = make_safe_collection_name(
|
||||
user="",
|
||||
workspace="",
|
||||
collection="",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -93,7 +93,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_only_special_characters(self):
|
||||
"""Test collection name creation with only special characters (should fallback to 'default')"""
|
||||
result = make_safe_collection_name(
|
||||
user="@@@!!!",
|
||||
workspace="@@@!!!",
|
||||
collection="---###",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -102,7 +102,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_whitespace_only(self):
|
||||
"""Test collection name creation with whitespace-only strings"""
|
||||
result = make_safe_collection_name(
|
||||
user=" \n\t ",
|
||||
workspace=" \n\t ",
|
||||
collection=" \r\n ",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -111,7 +111,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_mixed_valid_invalid_chars(self):
|
||||
"""Test collection name creation with mixed valid and invalid characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="user123@test",
|
||||
workspace="user123@test",
|
||||
collection="coll_2023.v1",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -147,7 +147,7 @@ class TestMilvusCollectionNaming:
|
|||
long_collection = "b" * 100
|
||||
|
||||
result = make_safe_collection_name(
|
||||
user=long_user,
|
||||
workspace=long_user,
|
||||
collection=long_collection,
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -159,7 +159,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_numeric_values(self):
|
||||
"""Test collection name creation with numeric user/collection values"""
|
||||
result = make_safe_collection_name(
|
||||
user="user123",
|
||||
workspace="user123",
|
||||
collection="collection456",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -168,7 +168,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_case_sensitivity(self):
|
||||
"""Test that collection name creation preserves case"""
|
||||
result = make_safe_collection_name(
|
||||
user="TestUser",
|
||||
workspace="TestUser",
|
||||
collection="TestCollection",
|
||||
prefix="Doc"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,9 +20,8 @@ def processor():
|
|||
)
|
||||
|
||||
|
||||
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
|
||||
user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"):
|
||||
metadata = Metadata(id=doc_id, collection=collection)
|
||||
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
|
|
@ -127,7 +126,7 @@ class TestDocumentEmbeddingsProcessor:
|
|||
@pytest.mark.asyncio
|
||||
async def test_metadata_preserved(self, processor):
|
||||
"""Output should carry the original metadata."""
|
||||
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
|
||||
msg = _make_chunk_message(collection="reports", doc_id="d1")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]]
|
||||
|
|
@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor:
|
|||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "reports"
|
||||
assert result.metadata.id == "d1"
|
||||
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"):
|
|||
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
|
||||
|
||||
|
||||
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
def _make_message(entities, doc_id="doc-1", collection="default"):
|
||||
metadata = Metadata(id=doc_id, collection=collection)
|
||||
value = EntityContexts(metadata=metadata, entities=entities)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
|
|
@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing:
|
|||
_make_entity_context(f"E{i}", f"ctx {i}")
|
||||
for i in range(5)
|
||||
]
|
||||
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
|
||||
msg = _make_message(entities, doc_id="doc-42", collection="main")
|
||||
|
||||
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
|
||||
mock_output = AsyncMock()
|
||||
|
|
@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing:
|
|||
for call in mock_output.send.call_args_list:
|
||||
result = call[0][0]
|
||||
assert result.metadata.id == "doc-42"
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
await processor.on_schema_config("default", config_data, 1)
|
||||
|
||||
assert 'customers' in processor.schemas
|
||||
assert processor.schemas['customers'].name == 'customers'
|
||||
assert len(processor.schemas['customers'].fields) == 3
|
||||
assert 'customers' in processor.schemas["default"]
|
||||
assert processor.schemas["default"]['customers'].name == 'customers'
|
||||
assert len(processor.schemas["default"]['customers'].fields) == 3
|
||||
|
||||
async def test_on_schema_config_handles_missing_type(self):
|
||||
"""Test that missing schema type is handled gracefully"""
|
||||
|
|
@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
'other_type': {}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
await processor.on_schema_config("default", config_data, 1)
|
||||
|
||||
assert processor.schemas == {}
|
||||
assert processor.schemas.get("default", {}) == {}
|
||||
|
||||
async def test_on_message_drops_unknown_collection(self):
|
||||
"""Test that messages for unknown collections are dropped"""
|
||||
|
|
@ -285,7 +285,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('default', 'test_collection')] = {}
|
||||
# No schemas registered
|
||||
|
||||
metadata = MagicMock()
|
||||
|
|
@ -322,17 +322,19 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('default', 'test_collection')] = {}
|
||||
|
||||
# Set up schema
|
||||
processor.schemas['customers'] = RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
'customers': RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
|
|
@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
return MagicMock()
|
||||
|
||||
mock_flow = MagicMock(side_effect=flow_factory)
|
||||
mock_flow.workspace = "default"
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,200 @@
|
|||
"""
|
||||
Unit tests for extract_with_simplified_format.
|
||||
|
||||
Regression guard for the bug where the extractor read
|
||||
``result.object`` (singular, used for response_type="json") instead of
|
||||
``result.objects`` (plural, used for response_type="jsonl"). The
|
||||
extract-with-ontologies prompt is JSONL, so reading the wrong field
|
||||
silently dropped every extraction and left the knowledge graph
|
||||
populated only by ontology schema + document provenance.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.extract.kg.ontology.extract import Processor
|
||||
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def extractor():
|
||||
"""Create a Processor instance without running its heavy __init__.
|
||||
|
||||
Matches the pattern used in test_prompt_and_extraction.py: only
|
||||
the attributes the code under test touches need to be set.
|
||||
"""
|
||||
ex = object.__new__(Processor)
|
||||
ex.URI_PREFIXES = {
|
||||
"rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
|
||||
"rdfs:": "http://www.w3.org/2000/01/rdf-schema#",
|
||||
"owl:": "http://www.w3.org/2002/07/owl#",
|
||||
"xsd:": "http://www.w3.org/2001/XMLSchema#",
|
||||
}
|
||||
return ex
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def food_subset():
|
||||
"""A minimal food ontology subset the extracted entities reference."""
|
||||
return OntologySubset(
|
||||
ontology_id="food",
|
||||
classes={
|
||||
"Recipe": {
|
||||
"uri": "http://purl.org/ontology/fo/Recipe",
|
||||
"type": "owl:Class",
|
||||
"labels": [{"value": "Recipe", "lang": "en-gb"}],
|
||||
"comment": "A Recipe.",
|
||||
},
|
||||
"Food": {
|
||||
"uri": "http://purl.org/ontology/fo/Food",
|
||||
"type": "owl:Class",
|
||||
"labels": [{"value": "Food", "lang": "en-gb"}],
|
||||
"comment": "A Food.",
|
||||
},
|
||||
},
|
||||
object_properties={
|
||||
"ingredients": {
|
||||
"uri": "http://purl.org/ontology/fo/ingredients",
|
||||
"type": "owl:ObjectProperty",
|
||||
"labels": [{"value": "ingredients", "lang": "en-gb"}],
|
||||
"comment": "Relates a recipe to its ingredients.",
|
||||
"domain": "Recipe",
|
||||
"range": "Food",
|
||||
},
|
||||
},
|
||||
datatype_properties={},
|
||||
metadata={
|
||||
"name": "Food Ontology",
|
||||
"namespace": "http://purl.org/ontology/fo/",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _flow_with_prompt_result(prompt_result):
|
||||
"""Build the ``flow(name)`` callable the extractor invokes.
|
||||
|
||||
``extract_with_simplified_format`` calls
|
||||
``flow("prompt-request").prompt(...)`` — so we need ``flow`` to be
|
||||
callable, return an object whose ``.prompt`` is an AsyncMock that
|
||||
resolves to ``prompt_result``.
|
||||
"""
|
||||
prompt_service = MagicMock()
|
||||
prompt_service.prompt = AsyncMock(return_value=prompt_result)
|
||||
|
||||
def flow(name):
|
||||
assert name == "prompt-request", (
|
||||
f"extractor should only invoke flow('prompt-request'), "
|
||||
f"got {name!r}"
|
||||
)
|
||||
return prompt_service
|
||||
|
||||
return flow, prompt_service.prompt
|
||||
|
||||
|
||||
class TestReadsObjectsForJsonlPrompt:
|
||||
"""extract-with-ontologies is a JSONL prompt; the extractor must
|
||||
read ``result.objects``, not ``result.object``."""
|
||||
|
||||
async def test_populated_objects_produces_triples(
|
||||
self, extractor, food_subset,
|
||||
):
|
||||
"""Happy path: PromptResult with populated .objects -> non-empty
|
||||
triples list."""
|
||||
|
||||
prompt_result = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"type": "entity", "entity": "Cornish Pasty",
|
||||
"entity_type": "Recipe"},
|
||||
{"type": "entity", "entity": "beef",
|
||||
"entity_type": "Food"},
|
||||
{"type": "relationship",
|
||||
"subject": "Cornish Pasty", "subject_type": "Recipe",
|
||||
"relation": "ingredients",
|
||||
"object": "beef", "object_type": "Food"},
|
||||
],
|
||||
)
|
||||
|
||||
flow, prompt_mock = _flow_with_prompt_result(prompt_result)
|
||||
|
||||
triples = await extractor.extract_with_simplified_format(
|
||||
flow, "some chunk", food_subset, {"text": "some chunk"},
|
||||
)
|
||||
|
||||
prompt_mock.assert_awaited_once()
|
||||
assert triples, (
|
||||
"extract_with_simplified_format returned no triples; if "
|
||||
"this fails, the extractor is probably reading .object "
|
||||
"instead of .objects again"
|
||||
)
|
||||
|
||||
async def test_none_objects_returns_empty_without_crashing(
|
||||
self, extractor, food_subset,
|
||||
):
|
||||
"""The exact shape that hit production on v2.3: the extractor
|
||||
was reading ``.object`` for a JSONL prompt, which returned
|
||||
``None`` and tripped the parser's 'Unexpected response type'
|
||||
path. With the fix we read ``.objects``; if that's also
|
||||
``None`` we must still return ``[]`` cleanly, not crash."""
|
||||
|
||||
prompt_result = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=None,
|
||||
)
|
||||
|
||||
flow, _ = _flow_with_prompt_result(prompt_result)
|
||||
|
||||
triples = await extractor.extract_with_simplified_format(
|
||||
flow, "chunk", food_subset, {"text": "chunk"},
|
||||
)
|
||||
|
||||
assert triples == []
|
||||
|
||||
async def test_empty_objects_returns_empty(
|
||||
self, extractor, food_subset,
|
||||
):
|
||||
"""Valid JSONL response with zero entries should yield zero
|
||||
triples, not raise."""
|
||||
|
||||
prompt_result = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[],
|
||||
)
|
||||
|
||||
flow, _ = _flow_with_prompt_result(prompt_result)
|
||||
|
||||
triples = await extractor.extract_with_simplified_format(
|
||||
flow, "chunk", food_subset, {"text": "chunk"},
|
||||
)
|
||||
|
||||
assert triples == []
|
||||
|
||||
async def test_ignores_object_field_for_jsonl_prompt(
|
||||
self, extractor, food_subset,
|
||||
):
|
||||
"""If ``.object`` is somehow set but ``.objects`` is None, the
|
||||
extractor must not silently fall back to ``.object``. This
|
||||
guards against a well-meaning regression that "helpfully"
|
||||
re-adds fallback fields.
|
||||
|
||||
The extractor should read only ``.objects`` for this prompt;
|
||||
when that is None we expect the empty-result path.
|
||||
"""
|
||||
|
||||
prompt_result = PromptResult(
|
||||
response_type="json",
|
||||
object={"not": "the field we should be reading"},
|
||||
objects=None,
|
||||
)
|
||||
|
||||
flow, _ = _flow_with_prompt_result(prompt_result)
|
||||
|
||||
triples = await extractor.extract_with_simplified_format(
|
||||
flow, "chunk", food_subset, {"text": "chunk"},
|
||||
)
|
||||
|
||||
assert triples == [], (
|
||||
"Extractor fell back to .object for a JSONL prompt — "
|
||||
"this is the regression shape we are trying to prevent"
|
||||
)
|
||||
|
|
@ -231,6 +231,52 @@ class TestTripleValidation:
|
|||
is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset)
|
||||
assert is_valid == expected, f"Validation of {predicate} should be {expected}"
|
||||
|
||||
def test_validates_domain_correctly_with_entity_types(self, extractor, sample_ontology_subset):
|
||||
"""Test domain validation correctly compares against extracted entity_types."""
|
||||
subject = "my-recipe"
|
||||
predicate = "produces"
|
||||
object_val = "my-food"
|
||||
|
||||
# Proper domain for produces is Recipe
|
||||
entity_types = {
|
||||
"my-recipe": "Recipe",
|
||||
"my-food": "Food"
|
||||
}
|
||||
|
||||
is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types)
|
||||
assert is_valid, "Valid domain should be accepted"
|
||||
|
||||
# Invalid domain
|
||||
entity_types_invalid = {
|
||||
"my-recipe": "Ingredient",
|
||||
"my-food": "Food"
|
||||
}
|
||||
is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid)
|
||||
assert not is_invalid, "Invalid domain should be rejected"
|
||||
|
||||
def test_validates_range_correctly_with_entity_types(self, extractor, sample_ontology_subset):
|
||||
"""Test range validation correctly compares against extracted entity_types."""
|
||||
subject = "my-recipe"
|
||||
predicate = "produces"
|
||||
object_val = "my-food"
|
||||
|
||||
# Proper range for produces is Food
|
||||
entity_types = {
|
||||
"my-recipe": "Recipe",
|
||||
"my-food": "Food"
|
||||
}
|
||||
|
||||
is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types)
|
||||
assert is_valid, "Valid range should be accepted"
|
||||
|
||||
# Invalid range
|
||||
entity_types_invalid = {
|
||||
"my-recipe": "Recipe",
|
||||
"my-food": "Recipe"
|
||||
}
|
||||
is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid)
|
||||
assert not is_invalid, "Invalid range should be rejected"
|
||||
|
||||
|
||||
class TestTripleParsing:
|
||||
"""Test suite for parsing triples from LLM responses."""
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.extract.kg.definitions.extract import (
|
||||
Processor, default_triples_batch_size, default_entity_batch_size,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
from trustgraph.schema import (
|
||||
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
|
|
@ -33,11 +34,10 @@ def _make_defn(entity, definition):
|
|||
return {"entity": entity, "definition": definition}
|
||||
|
||||
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
|
||||
user="user-1", collection="col-1", document_id=""):
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
|
||||
chunk = Chunk(
|
||||
metadata=Metadata(
|
||||
id=meta_id, root=root, user=user, collection=collection,
|
||||
id=meta_id, root=root, collection=collection,
|
||||
),
|
||||
chunk=text.encode("utf-8"),
|
||||
document_id=document_id,
|
||||
|
|
@ -51,8 +51,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_ecs_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
if isinstance(prompt_result, list):
|
||||
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
|
||||
else:
|
||||
wrapped = PromptResult(response_type="text", text=prompt_result)
|
||||
mock_prompt_client.extract_definitions = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=wrapped
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
@ -224,8 +228,7 @@ class TestMetadataPreservation:
|
|||
defs = [_make_defn("X", "def X")]
|
||||
flow, triples_pub, _, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-1", root="r-1",
|
||||
user="u-1", collection="coll-1",
|
||||
"text", meta_id="c-1", root="r-1", collection="coll-1",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
|
@ -233,7 +236,6 @@ class TestMetadataPreservation:
|
|||
for triples_msg in _sent_triples(triples_pub):
|
||||
assert triples_msg.metadata.id == "c-1"
|
||||
assert triples_msg.metadata.root == "r-1"
|
||||
assert triples_msg.metadata.user == "u-1"
|
||||
assert triples_msg.metadata.collection == "coll-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -242,8 +244,7 @@ class TestMetadataPreservation:
|
|||
defs = [_make_defn("X", "def X")]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-2", root="r-2",
|
||||
user="u-2", collection="coll-2",
|
||||
"text", meta_id="c-2", root="r-2", collection="coll-2",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
|
|||
from trustgraph.schema import (
|
||||
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -37,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True):
|
|||
}
|
||||
|
||||
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
|
||||
user="user-1", collection="col-1", document_id=""):
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
|
||||
"""Build a mock message wrapping a Chunk."""
|
||||
chunk = Chunk(
|
||||
metadata=Metadata(
|
||||
id=meta_id, root=root, user=user, collection=collection,
|
||||
id=meta_id, root=root, collection=collection,
|
||||
),
|
||||
chunk=text.encode("utf-8"),
|
||||
document_id=document_id,
|
||||
|
|
@ -58,7 +58,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.extract_relationships = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=prompt_result,
|
||||
)
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
@ -185,8 +188,7 @@ class TestMetadataPreservation:
|
|||
rels = [_make_rel("X", "rel", "Y")]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-1", root="r-1",
|
||||
user="u-1", collection="coll-1",
|
||||
"text", meta_id="c-1", root="r-1", collection="coll-1",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
|
@ -194,7 +196,6 @@ class TestMetadataPreservation:
|
|||
for triples_msg in _sent_triples(pub):
|
||||
assert triples_msg.metadata.id == "c-1"
|
||||
assert triples_msg.metadata.root == "r-1"
|
||||
assert triples_msg.metadata.user == "u-1"
|
||||
assert triples_msg.metadata.collection == "coll-1"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
|
|||
ConfigReceiver.config_loader = Mock()
|
||||
|
||||
|
||||
def _notify(version, changes):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=version, changes=changes)
|
||||
return msg
|
||||
|
||||
|
||||
class TestConfigReceiver:
|
||||
"""Test cases for ConfigReceiver class"""
|
||||
|
||||
|
|
@ -47,98 +53,70 @@ class TestConfigReceiver:
|
|||
assert handler2 in config_receiver.flow_handlers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_new_version(self):
|
||||
"""Test on_config_notify triggers fetch for newer version"""
|
||||
async def test_on_config_notify_new_version_fetches_per_workspace(self):
|
||||
"""Notify with newer version fetches each affected workspace."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
# Mock fetch_and_apply
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with newer version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=["flow"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
assert len(fetch_calls) == 1
|
||||
msg = _notify(2, {"flow": ["ws1", "ws2"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert set(fetch_calls) == {"ws1", "ws2"}
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_old_version_ignored(self):
|
||||
"""Test on_config_notify ignores older versions"""
|
||||
"""Older-version notifies are ignored."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 5
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with older version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=3, types=["flow"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
assert len(fetch_calls) == 0
|
||||
msg = _notify(3, {"flow": ["ws1"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert fetch_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_irrelevant_types_ignored(self):
|
||||
"""Test on_config_notify ignores types the gateway doesn't care about"""
|
||||
"""Notifies without flow changes advance version but skip fetch."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with non-flow type
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
# Version should be updated but no fetch
|
||||
assert len(fetch_calls) == 0
|
||||
msg = _notify(2, {"prompt": ["ws1"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert fetch_calls == []
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_flow_type_triggers_fetch(self):
|
||||
"""Test on_config_notify fetches for flow-related types"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
for type_name in ["flow", "active-flow"]:
|
||||
fetch_calls.clear()
|
||||
config_receiver.config_version = 1
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=[type_name])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_exception_handling(self):
|
||||
"""Test on_config_notify handles exceptions gracefully"""
|
||||
"""on_config_notify swallows exceptions from message decode."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create notify message that causes an exception
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.side_effect = Exception("Test exception")
|
||||
|
||||
|
|
@ -146,19 +124,18 @@ class TestConfigReceiver:
|
|||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_new_flows(self):
|
||||
"""Test fetch_and_apply starts new flows"""
|
||||
async def test_fetch_and_apply_workspace_starts_new_flows(self):
|
||||
"""fetch_and_apply_workspace starts newly-configured flows."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock _create_config_client to return a mock client
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
"flow2": '{"name": "test_flow_2"}'
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -167,36 +144,39 @@ class TestConfigReceiver:
|
|||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
start_flow_calls = []
|
||||
async def mock_start_flow(id, flow):
|
||||
start_flow_calls.append((id, flow))
|
||||
|
||||
async def mock_start_flow(workspace, id, flow):
|
||||
start_flow_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert config_receiver.config_version == 5
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow1" in config_receiver.flows["default"]
|
||||
assert "flow2" in config_receiver.flows["default"]
|
||||
assert len(start_flow_calls) == 2
|
||||
assert all(c[0] == "default" for c in start_flow_calls)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_removed_flows(self):
|
||||
"""Test fetch_and_apply stops removed flows"""
|
||||
async def test_fetch_and_apply_workspace_stops_removed_flows(self):
|
||||
"""fetch_and_apply_workspace stops flows no longer configured."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
"default": {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"},
|
||||
}
|
||||
}
|
||||
|
||||
# Config now only has flow1
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}'
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -205,20 +185,22 @@ class TestConfigReceiver:
|
|||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
stop_flow_calls = []
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_flow_calls.append((id, flow))
|
||||
|
||||
async def mock_stop_flow(workspace, id, flow):
|
||||
stop_flow_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
assert "flow1" in config_receiver.flows["default"]
|
||||
assert "flow2" not in config_receiver.flows["default"]
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0][0] == "flow2"
|
||||
assert stop_flow_calls[0][:2] == ("default", "flow2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_no_flows(self):
|
||||
"""Test fetch_and_apply with empty config"""
|
||||
async def test_fetch_and_apply_workspace_with_no_flows(self):
|
||||
"""Empty workspace config clears any local flow state."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
|
@ -231,88 +213,100 @@ class TestConfigReceiver:
|
|||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.flows.get("default", {}) == {}
|
||||
assert config_receiver.config_version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handlers(self):
|
||||
"""Test start_flow method with multiple handlers"""
|
||||
"""start_flow fans out to every registered flow handler."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.start_flow = Mock()
|
||||
handler1.start_flow = AsyncMock()
|
||||
handler2 = Mock()
|
||||
handler2.start_flow = Mock()
|
||||
handler2.start_flow = AsyncMock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
await config_receiver.start_flow("default", "flow1", flow_data)
|
||||
|
||||
handler1.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler1.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
handler2.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handler_exception(self):
|
||||
"""Test start_flow method handles handler exceptions"""
|
||||
"""Handler exceptions in start_flow do not propagate."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler = Mock()
|
||||
handler.start_flow = Mock(side_effect=Exception("Handler error"))
|
||||
handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
await config_receiver.start_flow("default", "flow1", flow_data)
|
||||
|
||||
handler.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handlers(self):
|
||||
"""Test stop_flow method with multiple handlers"""
|
||||
"""stop_flow fans out to every registered flow handler."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.stop_flow = Mock()
|
||||
handler1.stop_flow = AsyncMock()
|
||||
handler2 = Mock()
|
||||
handler2.stop_flow = Mock()
|
||||
handler2.stop_flow = AsyncMock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
await config_receiver.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler1.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
handler2.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handler_exception(self):
|
||||
"""Test stop_flow method handles handler exceptions"""
|
||||
"""Handler exceptions in stop_flow do not propagate."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler = Mock()
|
||||
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
|
||||
handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
await config_receiver.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
handler.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@patch('asyncio.create_task')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -329,25 +323,25 @@ class TestConfigReceiver:
|
|||
mock_create_task.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_mixed_flow_operations(self):
|
||||
"""Test fetch_and_apply with mixed add/remove operations"""
|
||||
async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
|
||||
"""fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
"default": {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"},
|
||||
}
|
||||
}
|
||||
|
||||
# Config removes flow1, keeps flow2, adds flow3
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
"flow3": '{"name": "test_flow_3"}'
|
||||
"flow3": '{"name": "test_flow_3"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -358,20 +352,22 @@ class TestConfigReceiver:
|
|||
start_calls = []
|
||||
stop_calls = []
|
||||
|
||||
async def mock_start_flow(id, flow):
|
||||
start_calls.append((id, flow))
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_calls.append((id, flow))
|
||||
async def mock_start_flow(workspace, id, flow):
|
||||
start_calls.append((workspace, id, flow))
|
||||
|
||||
async def mock_stop_flow(workspace, id, flow):
|
||||
stop_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
ws_flows = config_receiver.flows["default"]
|
||||
assert "flow1" not in ws_flows
|
||||
assert "flow2" in ws_flows
|
||||
assert "flow3" in ws_flows
|
||||
assert len(start_calls) == 1
|
||||
assert start_calls[0][0] == "flow3"
|
||||
assert start_calls[0][:2] == ("default", "flow3")
|
||||
assert len(stop_calls) == 1
|
||||
assert stop_calls[0][0] == "flow1"
|
||||
assert stop_calls[0][:2] == ("default", "flow1")
|
||||
|
|
|
|||
406
tests/unit/test_gateway/test_core_import_export_roundtrip.py
Normal file
406
tests/unit/test_gateway/test_core_import_export_roundtrip.py
Normal file
|
|
@ -0,0 +1,406 @@
|
|||
"""
|
||||
Round-trip unit tests for the core msgpack import/export gateway endpoints.
|
||||
|
||||
The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
|
||||
the responder callback and packs them into msgpack tuples. The kg-core
|
||||
import endpoint takes msgpack tuples back off the wire and rebuilds
|
||||
KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
|
||||
(whose translator decodes them into real dataclasses).
|
||||
|
||||
Regression coverage: the previous wire format used `"vectors"` (plural)
|
||||
in the entity blobs and embedded a stale `"m"` field that referenced the
|
||||
removed `Metadata.metadata` triples-list field. The export side hit a
|
||||
KeyError on first message; the import side built dicts that the
|
||||
KnowledgeRequestTranslator (separately fixed) couldn't decode. These
|
||||
tests pin both halves of the wire protocol.
|
||||
"""
|
||||
|
||||
import msgpack
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.core_export import CoreExport
|
||||
from trustgraph.gateway.dispatch.core_import import CoreImport
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
|
||||
# would emit). The vector wire key is *singular* on purpose; the export
|
||||
# side previously read the wrong key and crashed.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ge_response_dict():
|
||||
return {
|
||||
"graph-embeddings": {
|
||||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": {"t": "i", "i": "http://example.org/alice"},
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
},
|
||||
{
|
||||
"entity": {"t": "i", "i": "http://example.org/bob"},
|
||||
"vector": [0.4, 0.5, 0.6],
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _triples_response_dict():
|
||||
return {
|
||||
"triples": {
|
||||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"triples": [
|
||||
{
|
||||
"s": {"t": "i", "i": "http://example.org/alice"},
|
||||
"p": {"t": "i", "i": "http://example.org/knows"},
|
||||
"o": {"t": "i", "i": "http://example.org/bob"},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _make_request(id_="doc-1", workspace="alice"):
|
||||
request = Mock()
|
||||
request.query = {"id": id_, "workspace": workspace}
|
||||
return request
|
||||
|
||||
|
||||
def _make_data_reader(payload: bytes):
|
||||
"""Mock the aiohttp StreamReader: returns payload once, then EOF."""
|
||||
chunks = [payload, b""]
|
||||
|
||||
data = Mock()
|
||||
|
||||
async def fake_read(n):
|
||||
return chunks.pop(0) if chunks else b""
|
||||
|
||||
data.read = fake_read
|
||||
return data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export side: translator-shaped dict -> msgpack bytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreExportWireFormat:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
|
||||
async def test_export_packs_graph_embeddings_with_singular_vector(
|
||||
self, mock_kr_class,
|
||||
):
|
||||
"""The export side must read `ent["vector"]` and emit `v`. The
|
||||
previous bug was reading `ent["vectors"]` which KeyErrored against
|
||||
the translator output."""
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict, responder):
|
||||
await responder(_ge_response_dict(), True)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
response = AsyncMock()
|
||||
|
||||
async def fake_write(b):
|
||||
captured.append(b)
|
||||
|
||||
response.write = fake_write
|
||||
response.write_eof = AsyncMock()
|
||||
|
||||
ok = AsyncMock(return_value=response)
|
||||
error = AsyncMock()
|
||||
|
||||
exporter = CoreExport(backend=Mock())
|
||||
await exporter.process(
|
||||
data=Mock(),
|
||||
error=error,
|
||||
ok=ok,
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
# Did not raise, did not call error()
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
unpacker = msgpack.Unpacker()
|
||||
unpacker.feed(captured[0])
|
||||
items = list(unpacker)
|
||||
|
||||
assert len(items) == 1
|
||||
msg_type, payload = items[0]
|
||||
assert msg_type == "ge"
|
||||
|
||||
# Metadata envelope: only id/collection — no stale `m["m"]`.
|
||||
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
|
||||
|
||||
# Entities: each carries the *singular* `v` and the term envelope
|
||||
assert len(payload["e"]) == 2
|
||||
assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
|
||||
assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
|
||||
assert payload["e"][0]["e"]["i"] == "http://example.org/alice"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
|
||||
async def test_export_packs_triples(self, mock_kr_class):
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict, responder):
|
||||
await responder(_triples_response_dict(), True)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
response = AsyncMock()
|
||||
|
||||
async def fake_write(b):
|
||||
captured.append(b)
|
||||
|
||||
response.write = fake_write
|
||||
response.write_eof = AsyncMock()
|
||||
|
||||
ok = AsyncMock(return_value=response)
|
||||
error = AsyncMock()
|
||||
|
||||
exporter = CoreExport(backend=Mock())
|
||||
await exporter.process(
|
||||
data=Mock(), error=error, ok=ok, request=_make_request(),
|
||||
)
|
||||
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
unpacker = msgpack.Unpacker()
|
||||
unpacker.feed(captured[0])
|
||||
items = list(unpacker)
|
||||
assert len(items) == 1
|
||||
|
||||
msg_type, payload = items[0]
|
||||
assert msg_type == "t"
|
||||
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
|
||||
assert len(payload["t"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import side: msgpack bytes -> translator-shaped dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreImportWireFormat:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
|
||||
async def test_import_unpacks_graph_embeddings_to_singular_vector(
|
||||
self, mock_kr_class,
|
||||
):
|
||||
"""The import side must build dicts whose entity blobs have the
|
||||
singular `vector` key — that's what the KnowledgeRequestTranslator
|
||||
decode side reads. Previous bug emitted `vectors`."""
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict):
|
||||
captured.append(req_dict)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
# Build a msgpack tuple matching the new wire format
|
||||
payload = msgpack.packb((
|
||||
"ge",
|
||||
{
|
||||
"m": {"i": "doc-1", "c": "testcoll"},
|
||||
"e": [
|
||||
{
|
||||
"e": {"t": "i", "i": "http://example.org/alice"},
|
||||
"v": [0.1, 0.2, 0.3],
|
||||
},
|
||||
],
|
||||
},
|
||||
))
|
||||
|
||||
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
|
||||
error = AsyncMock()
|
||||
|
||||
importer = CoreImport(backend=Mock())
|
||||
await importer.process(
|
||||
data=_make_data_reader(payload),
|
||||
error=error,
|
||||
ok=ok,
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
req = captured[0]
|
||||
assert req["operation"] == "put-kg-core"
|
||||
assert req["workspace"] == "alice"
|
||||
assert req["id"] == "doc-1"
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
# Metadata envelope must NOT contain a stale `metadata` key
|
||||
# referencing the removed Metadata.metadata field.
|
||||
assert "metadata" not in ge["metadata"]
|
||||
assert ge["metadata"] == {
|
||||
"id": "doc-1",
|
||||
"collection": "default",
|
||||
}
|
||||
|
||||
# Entity blob carries the singular `vector` key
|
||||
assert len(ge["entities"]) == 1
|
||||
ent = ge["entities"][0]
|
||||
assert ent["vector"] == [0.1, 0.2, 0.3]
|
||||
assert "vectors" not in ent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
|
||||
async def test_import_unpacks_triples(self, mock_kr_class):
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict):
|
||||
captured.append(req_dict)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
payload = msgpack.packb((
|
||||
"t",
|
||||
{
|
||||
"m": {"i": "doc-1", "c": "testcoll"},
|
||||
"t": [
|
||||
{
|
||||
"s": {"t": "i", "i": "http://example.org/alice"},
|
||||
"p": {"t": "i", "i": "http://example.org/knows"},
|
||||
"o": {"t": "i", "i": "http://example.org/bob"},
|
||||
},
|
||||
],
|
||||
},
|
||||
))
|
||||
|
||||
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
|
||||
error = AsyncMock()
|
||||
|
||||
importer = CoreImport(backend=Mock())
|
||||
await importer.process(
|
||||
data=_make_data_reader(payload),
|
||||
error=error,
|
||||
ok=ok,
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
req = captured[0]
|
||||
triples = req["triples"]
|
||||
assert "metadata" not in triples["metadata"] # no stale field
|
||||
assert len(triples["triples"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full round-trip: export bytes feed directly into import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreImportExportRoundTrip:
|
||||
"""End-to-end: produce bytes via core_export, consume them via
|
||||
core_import, and verify the dict that lands at the import-side
|
||||
translator is structurally equivalent to what went in. This is the
|
||||
test that catches asymmetries between the two halves."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
|
||||
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
|
||||
async def test_graph_embeddings_round_trip(
|
||||
self, mock_export_kr_class, mock_import_kr_class,
|
||||
):
|
||||
# ----- export side: capture bytes -----
|
||||
export_bytes = []
|
||||
|
||||
async def fake_export_process(req_dict, responder):
|
||||
await responder(_ge_response_dict(), True)
|
||||
|
||||
export_kr = AsyncMock()
|
||||
export_kr.start = AsyncMock()
|
||||
export_kr.stop = AsyncMock()
|
||||
export_kr.process = fake_export_process
|
||||
mock_export_kr_class.return_value = export_kr
|
||||
|
||||
response = AsyncMock()
|
||||
|
||||
async def fake_write(b):
|
||||
export_bytes.append(b)
|
||||
|
||||
response.write = fake_write
|
||||
response.write_eof = AsyncMock()
|
||||
|
||||
exporter = CoreExport(backend=Mock())
|
||||
await exporter.process(
|
||||
data=Mock(),
|
||||
error=AsyncMock(),
|
||||
ok=AsyncMock(return_value=response),
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
assert len(export_bytes) == 1
|
||||
|
||||
# ----- import side: feed those bytes back in -----
|
||||
import_captured = []
|
||||
|
||||
async def fake_import_process(req_dict):
|
||||
import_captured.append(req_dict)
|
||||
|
||||
import_kr = AsyncMock()
|
||||
import_kr.start = AsyncMock()
|
||||
import_kr.stop = AsyncMock()
|
||||
import_kr.process = fake_import_process
|
||||
mock_import_kr_class.return_value = import_kr
|
||||
|
||||
importer = CoreImport(backend=Mock())
|
||||
await importer.process(
|
||||
data=_make_data_reader(export_bytes[0]),
|
||||
error=AsyncMock(),
|
||||
ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
# ----- verify the dict the importer would hand to the translator -----
|
||||
assert len(import_captured) == 1
|
||||
req = import_captured[0]
|
||||
|
||||
original = _ge_response_dict()["graph-embeddings"]
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
# The import side overrides id from the URL query (intentional),
|
||||
# so we only round-trip the entity payload itself.
|
||||
assert ge["metadata"]["id"] == original["metadata"]["id"]
|
||||
|
||||
assert len(ge["entities"]) == len(original["entities"])
|
||||
for got, want in zip(ge["entities"], original["entities"]):
|
||||
assert got["vector"] == want["vector"]
|
||||
assert got["entity"] == want["entity"]
|
||||
|
|
@ -72,10 +72,10 @@ class TestDispatcherManager:
|
|||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await manager.start_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" in manager.flows
|
||||
assert manager.flows["flow1"] == flow_data
|
||||
await manager.start_flow("default", "flow1", flow_data)
|
||||
|
||||
assert ("default", "flow1") in manager.flows
|
||||
assert manager.flows[("default", "flow1")] == flow_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow(self):
|
||||
|
|
@ -86,11 +86,11 @@ class TestDispatcherManager:
|
|||
|
||||
# Pre-populate with a flow
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
manager.flows["flow1"] = flow_data
|
||||
|
||||
await manager.stop_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" not in manager.flows
|
||||
manager.flows[("default", "flow1")] = flow_data
|
||||
|
||||
await manager.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
assert ("default", "flow1") not in manager.flows
|
||||
|
||||
def test_dispatch_global_service_returns_wrapper(self):
|
||||
"""Test dispatch_global_service returns DispatcherWrapper"""
|
||||
|
|
@ -275,12 +275,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"queue": "test_queue"}
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
|
|
@ -290,7 +290,7 @@ class TestDispatcherManager:
|
|||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
|
||||
params = {"flow": "test_flow", "kind": "triples"}
|
||||
result = await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
|
|
@ -298,7 +298,7 @@ class TestDispatcherManager:
|
|||
backend=mock_backend,
|
||||
ws="ws",
|
||||
running="running",
|
||||
queue={"queue": "test_queue"}
|
||||
queue="test_queue"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
assert result == mock_dispatcher
|
||||
|
|
@ -326,12 +326,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"queue": "test_queue"}
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
|
||||
mock_dispatchers.__contains__.return_value = False
|
||||
|
||||
|
|
@ -348,12 +348,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"queue": "test_queue"}
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
|
|
@ -370,7 +370,7 @@ class TestDispatcherManager:
|
|||
backend=mock_backend,
|
||||
ws="ws",
|
||||
running="running",
|
||||
queue={"queue": "test_queue"},
|
||||
queue="test_queue",
|
||||
consumer="api-gateway-test-uuid",
|
||||
subscriber="api-gateway-test-uuid"
|
||||
)
|
||||
|
|
@ -404,7 +404,7 @@ class TestDispatcherManager:
|
|||
params = {"flow": "test_flow", "kind": "agent"}
|
||||
result = await manager.process_flow_service("data", "responder", params)
|
||||
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
|
||||
assert result == "flow_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -415,14 +415,14 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Add flow to the flows dictionary
|
||||
manager.flows["test_flow"] = {"services": {"agent": {}}}
|
||||
|
||||
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="cached_result")
|
||||
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
assert result == "cached_result"
|
||||
|
|
@ -435,7 +435,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
|
|
@ -443,7 +443,7 @@ class TestDispatcherManager:
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
|
|
@ -452,23 +452,23 @@ class TestDispatcherManager:
|
|||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
request_queue="agent_request_queue",
|
||||
response_queue="agent_response_queue",
|
||||
timeout=120,
|
||||
consumer="api-gateway-test_flow-agent-request",
|
||||
subscriber="api-gateway-test_flow-agent-request"
|
||||
consumer="api-gateway-default-test_flow-agent-request",
|
||||
subscriber="api-gateway-default-test_flow-agent-request"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
|
||||
assert result == "new_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -479,36 +479,36 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"text-load": {"queue": "text_load_queue"}
|
||||
"text-load": {"flow": "text_load_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
|
||||
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
|
||||
mock_rr_dispatchers.__contains__.return_value = False
|
||||
mock_sender_dispatchers.__contains__.return_value = True
|
||||
|
||||
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="sender_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
|
||||
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
queue={"queue": "text_load_queue"}
|
||||
queue="text_load_queue"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
|
||||
assert result == "sender_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -519,7 +519,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
|
||||
|
|
@ -529,14 +529,14 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow without agent interface
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"text-completion": {"request": "req", "response": "resp"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_invalid_kind(self):
|
||||
|
|
@ -546,7 +546,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow with interface but unsupported kind
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"invalid-kind": {"request": "req", "response": "resp"}
|
||||
}
|
||||
|
|
@ -558,7 +558,7 @@ class TestDispatcherManager:
|
|||
mock_sender_dispatchers.__contains__.return_value = False
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid kind"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
|
||||
|
|
@ -608,7 +608,7 @@ class TestDispatcherManager:
|
|||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
|
|
@ -630,7 +630,7 @@ class TestDispatcherManager:
|
|||
mock_rr_dispatchers.__contains__.return_value = True
|
||||
|
||||
results = await asyncio.gather(*[
|
||||
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
for _ in range(5)
|
||||
])
|
||||
|
||||
|
|
@ -638,5 +638,5 @@ class TestDispatcherManager:
|
|||
"Dispatcher class instantiated more than once — duplicate consumer bug"
|
||||
)
|
||||
assert mock_dispatcher.start.call_count == 1
|
||||
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
|
||||
assert all(r == "result" for r in results)
|
||||
75
tests/unit/test_gateway/test_endpoint_i18n.py
Normal file
75
tests/unit/test_gateway/test_endpoint_i18n.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Tests for Gateway i18n pack endpoint."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.endpoint.i18n import I18nPackEndpoint
|
||||
|
||||
|
||||
class TestI18nPackEndpoint:
|
||||
|
||||
def test_i18n_endpoint_initialization(self):
|
||||
mock_auth = MagicMock()
|
||||
|
||||
endpoint = I18nPackEndpoint(
|
||||
endpoint_path="/api/v1/i18n/packs/{lang}",
|
||||
auth=mock_auth,
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/v1/i18n/packs/{lang}"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.operation == "service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_i18n_endpoint_start_method(self):
|
||||
mock_auth = MagicMock()
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_registers_get_handler(self):
|
||||
mock_auth = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unauthorized_on_invalid_auth_scheme(self):
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
|
||||
request = MagicMock()
|
||||
request.path = "/api/v1/i18n/packs/en"
|
||||
request.headers = {"Authorization": "Token abc"}
|
||||
request.match_info = {"lang": "en"}
|
||||
|
||||
resp = await endpoint.handle(request)
|
||||
assert isinstance(resp, web.HTTPUnauthorized)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_returns_pack_when_permitted(self):
|
||||
mock_auth = MagicMock()
|
||||
mock_auth.permitted.return_value = True
|
||||
|
||||
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
|
||||
|
||||
request = MagicMock()
|
||||
request.path = "/api/v1/i18n/packs/en"
|
||||
request.headers = {}
|
||||
request.match_info = {"lang": "en"}
|
||||
|
||||
resp = await endpoint.handle(request)
|
||||
|
||||
assert resp.status == 200
|
||||
payload = json.loads(resp.body.decode("utf-8"))
|
||||
assert isinstance(payload, dict)
|
||||
assert "cli.verify_system_status.title" in payload
|
||||
|
|
@ -0,0 +1,241 @@
|
|||
"""
|
||||
Unit tests for entity contexts import dispatcher.
|
||||
|
||||
Tests the business logic of EntityContextsImport while mocking the
|
||||
Publisher and websocket components.
|
||||
|
||||
Regression coverage: a previous version constructed Metadata(metadata=...)
|
||||
which raised TypeError at runtime as soon as a message was received. These
|
||||
tests exercise receive() end-to-end so any future schema/kwarg drift in
|
||||
the Metadata or EntityContexts construction is caught immediately.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport
|
||||
from trustgraph.schema import EntityContexts, EntityContext, Metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backend():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_running():
|
||||
running = Mock()
|
||||
running.get.return_value = True
|
||||
running.stop = Mock()
|
||||
return running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
ws = Mock()
|
||||
ws.close = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message():
|
||||
"""Sample entity-contexts websocket message."""
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-123",
|
||||
"user": "testuser",
|
||||
"collection": "testcollection",
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": {"v": "http://example.org/alice", "e": True},
|
||||
"context": "Alice is a person.",
|
||||
},
|
||||
{
|
||||
"entity": {"v": "http://example.org/bob", "e": True},
|
||||
"context": "Bob is a person.",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_entities_message():
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-empty",
|
||||
"user": "u",
|
||||
"collection": "c",
|
||||
},
|
||||
"entities": [],
|
||||
}
|
||||
|
||||
|
||||
class TestEntityContextsImportInitialization:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="ec-queue",
|
||||
)
|
||||
|
||||
mock_publisher_class.assert_called_once_with(
|
||||
mock_backend,
|
||||
topic="ec-queue",
|
||||
schema=EntityContexts,
|
||||
)
|
||||
assert dispatcher.ws is mock_websocket
|
||||
assert dispatcher.running is mock_running
|
||||
assert dispatcher.publisher is instance
|
||||
|
||||
|
||||
class TestEntityContextsImportLifecycle:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.start()
|
||||
instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(
|
||||
self, mock_publisher_class, mock_backend, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=None, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestEntityContextsImportMessageProcessing:
|
||||
"""Regression coverage for receive(): catches Metadata/schema drift."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_constructs_entity_contexts_correctly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
# If Metadata or EntityContexts gain/lose kwargs, this raises
|
||||
# TypeError — exactly the regression we want to catch.
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
call_args = instance.send.call_args
|
||||
assert call_args[0][0] is None
|
||||
|
||||
sent = call_args[0][1]
|
||||
assert isinstance(sent, EntityContexts)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
assert all(isinstance(e, EntityContext) for e in sent.entities)
|
||||
assert sent.entities[0].context == "Alice is a person."
|
||||
assert sent.entities[1].context == "Bob is a person."
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_entities(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, empty_entities_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_entities_message
|
||||
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
sent = instance.send.call_args[0][1]
|
||||
assert isinstance(sent, EntityContexts)
|
||||
assert sent.entities == []
|
||||
assert sent.metadata.id == "doc-empty"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
with pytest.raises(RuntimeError, match="publish failed"):
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
|
@ -158,7 +158,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
message_type="explain",
|
||||
content="",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
|
|
@ -179,7 +179,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
message_type="thought",
|
||||
content="I need to think...",
|
||||
)
|
||||
|
||||
|
|
@ -190,7 +190,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_triples=sample_triples(),
|
||||
end_of_dialog=False,
|
||||
|
|
@ -203,7 +203,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
message_type="answer",
|
||||
content="The answer is...",
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,246 @@
|
|||
"""
|
||||
Unit tests for graph embeddings import dispatcher.
|
||||
|
||||
Tests the business logic of GraphEmbeddingsImport while mocking the
|
||||
Publisher and websocket components.
|
||||
|
||||
Regression coverage: a previous version of EntityContextsImport
|
||||
constructed Metadata(metadata=...) which raised TypeError at runtime as
|
||||
soon as a message was received. The same shape of bug can occur here, so
|
||||
these tests exercise receive() end-to-end to catch any future schema or
|
||||
kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport
|
||||
from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backend():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_running():
|
||||
running = Mock()
|
||||
running.get.return_value = True
|
||||
running.stop = Mock()
|
||||
return running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
ws = Mock()
|
||||
ws.close = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message():
|
||||
"""Sample graph-embeddings websocket message."""
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-123",
|
||||
"user": "testuser",
|
||||
"collection": "testcollection",
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": {"v": "http://example.org/alice", "e": True},
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
},
|
||||
{
|
||||
"entity": {"v": "http://example.org/bob", "e": True},
|
||||
"vector": [0.4, 0.5, 0.6],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_entities_message():
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-empty",
|
||||
"user": "u",
|
||||
"collection": "c",
|
||||
},
|
||||
"entities": [],
|
||||
}
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportInitialization:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="ge-queue",
|
||||
)
|
||||
|
||||
mock_publisher_class.assert_called_once_with(
|
||||
mock_backend,
|
||||
topic="ge-queue",
|
||||
schema=GraphEmbeddings,
|
||||
)
|
||||
assert dispatcher.ws is mock_websocket
|
||||
assert dispatcher.running is mock_running
|
||||
assert dispatcher.publisher is instance
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportLifecycle:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.start()
|
||||
instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(
|
||||
self, mock_publisher_class, mock_backend, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=None, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportMessageProcessing:
|
||||
"""Regression coverage for receive(): catches Metadata/schema drift."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_constructs_graph_embeddings_correctly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
# If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose
|
||||
# kwargs, this raises TypeError — exactly the regression we want
|
||||
# to catch.
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
call_args = instance.send.call_args
|
||||
assert call_args[0][0] is None
|
||||
|
||||
sent = call_args[0][1]
|
||||
assert isinstance(sent, GraphEmbeddings)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
assert all(isinstance(e, EntityEmbeddings) for e in sent.entities)
|
||||
# Lock in the wire format: incoming "vector" key (singular,
|
||||
# list[float]) maps to EntityEmbeddings.vector. This mirrors
|
||||
# serialize_graph_embeddings() on the export side.
|
||||
assert sent.entities[0].vector == [0.1, 0.2, 0.3]
|
||||
assert sent.entities[1].vector == [0.4, 0.5, 0.6]
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_entities(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, empty_entities_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_entities_message
|
||||
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
sent = instance.send.call_args[0][1]
|
||||
assert isinstance(sent, GraphEmbeddings)
|
||||
assert sent.entities == []
|
||||
assert sent.metadata.id == "doc-empty"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
with pytest.raises(RuntimeError, match="publish failed"):
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
|
@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing:
|
|||
|
||||
# Check metadata
|
||||
assert sent_object.metadata.id == "obj-123"
|
||||
assert sent_object.metadata.user == "testuser"
|
||||
assert sent_object.metadata.collection == "testcollection"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
|
|
|
|||
|
|
@ -171,6 +171,14 @@ class TestApi:
|
|||
patch('aiohttp.web.run_app') as mock_run_app:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
# Api.run() passes self.app_factory() — a coroutine — to
|
||||
# web.run_app, which would normally consume it inside its own
|
||||
# event loop. Since we mock run_app, close the coroutine here
|
||||
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
|
||||
def _consume_coro(coro, **kwargs):
|
||||
coro.close()
|
||||
mock_run_app.side_effect = _consume_coro
|
||||
|
||||
api = Api(port=8080)
|
||||
api.run()
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ class TestTextDocumentTranslator:
|
|||
)
|
||||
|
||||
assert msg.metadata.id == "doc-1"
|
||||
assert msg.metadata.user == "alice"
|
||||
assert msg.metadata.collection == "research"
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
||||
|
|
|
|||
|
|
@ -29,10 +29,9 @@ class Triple:
|
|||
self.o = o
|
||||
|
||||
class Metadata:
|
||||
def __init__(self, id, user, collection, root=""):
|
||||
def __init__(self, id, collection, root=""):
|
||||
self.id = id
|
||||
self.root = root
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
|
||||
class Triples:
|
||||
|
|
@ -108,7 +107,6 @@ def sample_triples(sample_triple):
|
|||
"""Sample Triples batch object"""
|
||||
metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -123,7 +121,6 @@ def sample_chunk():
|
|||
"""Sample text chunk for processing"""
|
||||
metadata = Metadata(
|
||||
id="test-chunk-456",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -322,7 +322,6 @@ This is not JSON at all
|
|||
assert isinstance(sent_triples, Triples)
|
||||
# Check metadata fields individually since implementation creates new Metadata object
|
||||
assert sent_triples.metadata.id == sample_metadata.id
|
||||
assert sent_triples.metadata.user == sample_metadata.user
|
||||
assert sent_triples.metadata.collection == sample_metadata.collection
|
||||
assert len(sent_triples.triples) == 1
|
||||
assert sent_triples.triples[0].s.iri == "test:subject"
|
||||
|
|
@ -346,7 +345,6 @@ This is not JSON at all
|
|||
assert isinstance(sent_contexts, EntityContexts)
|
||||
# Check metadata fields individually since implementation creates new Metadata object
|
||||
assert sent_contexts.metadata.id == sample_metadata.id
|
||||
assert sent_contexts.metadata.user == sample_metadata.user
|
||||
assert sent_contexts.metadata.collection == sample_metadata.collection
|
||||
assert len(sent_contexts.entities) == 1
|
||||
assert sent_contexts.entities[0].entity.iri == "test:entity"
|
||||
|
|
|
|||
|
|
@ -311,8 +311,7 @@ class TestObjectExtractionBusinessLogic:
|
|||
"""Test ExtractedObject creation and properties"""
|
||||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-extraction-001",
|
||||
user="test_user",
|
||||
id="test-extraction-001",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic:
|
|||
assert extracted_obj.values[0]["customer_id"] == "CUST001"
|
||||
assert extracted_obj.confidence == 0.95
|
||||
assert "John Doe" in extracted_obj.source_span
|
||||
assert extracted_obj.metadata.user == "test_user"
|
||||
|
||||
def test_config_parsing_error_handling(self):
|
||||
"""Test configuration parsing with invalid JSON"""
|
||||
|
|
|
|||
|
|
@ -371,7 +371,6 @@ class TestTripleConstructionLogic:
|
|||
|
||||
metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -384,7 +383,6 @@ class TestTripleConstructionLogic:
|
|||
# Assert
|
||||
assert isinstance(triples_batch, Triples)
|
||||
assert triples_batch.metadata.id == "test-doc-123"
|
||||
assert triples_batch.metadata.user == "test_user"
|
||||
assert triples_batch.metadata.collection == "test_collection"
|
||||
assert len(triples_batch.triples) == 2
|
||||
|
||||
|
|
|
|||
119
tests/unit/test_librarian/test_blob_store.py
Normal file
119
tests/unit/test_librarian/test_blob_store.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import asyncio
|
||||
import io
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from uuid import uuid4
|
||||
from minio.error import S3Error
|
||||
from trustgraph.librarian.blob_store import BlobStore
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_blob_store():
|
||||
"""Create a BlobStore with mocked Minio client."""
|
||||
mock_minio = MagicMock()
|
||||
with patch('trustgraph.librarian.blob_store.Minio', return_value=mock_minio):
|
||||
# Prevent ensure_bucket from making network calls during init
|
||||
with patch('trustgraph.librarian.blob_store.BlobStore.ensure_bucket'):
|
||||
store = BlobStore(
|
||||
endpoint="localhost:9000",
|
||||
access_key="access",
|
||||
secret_key="secret",
|
||||
bucket_name="test-bucket"
|
||||
)
|
||||
return store, mock_minio
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_success_no_retry():
|
||||
store, mock_minio = _make_blob_store()
|
||||
object_id = uuid4()
|
||||
|
||||
await store.add(object_id, b"data", "text/plain")
|
||||
|
||||
mock_minio.put_object.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_recovery_on_transient_failure():
|
||||
store, mock_minio = _make_blob_store()
|
||||
store.base_delay = 0 # Disable delay for fast tests
|
||||
|
||||
# Fail twice, succeed third time
|
||||
mock_minio.put_object.side_effect = [
|
||||
Exception("Error 1"),
|
||||
Exception("Error 2"),
|
||||
MagicMock()
|
||||
]
|
||||
|
||||
await store.add(uuid4(), b"data", "text/plain")
|
||||
|
||||
assert mock_minio.put_object.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_exhaustion_after_8_attempts():
|
||||
store, mock_minio = _make_blob_store()
|
||||
store.base_delay = 0
|
||||
|
||||
# Permanent failure
|
||||
mock_minio.put_object.side_effect = Exception("Permanent failure")
|
||||
|
||||
with pytest.raises(Exception, match="Permanent failure"):
|
||||
await store.add(uuid4(), b"data", "text/plain")
|
||||
|
||||
# Author requirement: exactly 8 attempts
|
||||
assert mock_minio.put_object.call_count == 8
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_error_triggers_retry():
|
||||
store, mock_minio = _make_blob_store()
|
||||
store.base_delay = 0
|
||||
|
||||
# Mock S3Error
|
||||
s3_err = S3Error("code", "msg", "res", "req", "host", None)
|
||||
mock_minio.get_object.side_effect = [s3_err, MagicMock()]
|
||||
|
||||
await store.get(uuid4())
|
||||
|
||||
assert mock_minio.get_object.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exponential_backoff_delays():
|
||||
store, mock_minio = _make_blob_store()
|
||||
# Use real base_delay to check math
|
||||
store.base_delay = 0.25
|
||||
|
||||
# Correct method name is stat_object, not get_size
|
||||
mock_minio.stat_object = MagicMock(side_effect=Exception("Wait"))
|
||||
|
||||
with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep:
|
||||
with pytest.raises(Exception):
|
||||
await store.get_size(uuid4())
|
||||
|
||||
# Should have 7 sleep calls for 8 attempts
|
||||
assert mock_sleep.call_count == 7
|
||||
|
||||
# Check actual sleep durations: 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0
|
||||
sleep_args = [call[0][0] for call in mock_sleep.call_args_list]
|
||||
assert sleep_args == [0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runs_in_executor():
|
||||
"""Verify that synchronous Minio calls are offloaded to an executor."""
|
||||
store, mock_minio = _make_blob_store()
|
||||
|
||||
# Mock response object with .read() method
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b"result"
|
||||
|
||||
with patch('asyncio.get_event_loop') as mock_loop:
|
||||
mock_loop_instance = MagicMock()
|
||||
mock_loop.return_value = mock_loop_instance
|
||||
mock_loop_instance.run_in_executor = AsyncMock(return_value=mock_response)
|
||||
|
||||
await store.get(uuid4())
|
||||
|
||||
mock_loop_instance.run_in_executor.assert_called_once()
|
||||
|
|
@ -22,6 +22,10 @@ def _make_librarian(min_chunk_size=1):
|
|||
"""Create a Librarian with mocked blob_store and table_store."""
|
||||
lib = Librarian.__new__(Librarian)
|
||||
lib.blob_store = MagicMock()
|
||||
lib.blob_store.create_multipart_upload = AsyncMock()
|
||||
lib.blob_store.upload_part = AsyncMock()
|
||||
lib.blob_store.complete_multipart_upload = AsyncMock()
|
||||
lib.blob_store.abort_multipart_upload = AsyncMock()
|
||||
lib.table_store = AsyncMock()
|
||||
lib.load_document = AsyncMock()
|
||||
lib.min_chunk_size = min_chunk_size
|
||||
|
|
@ -29,12 +33,12 @@ def _make_librarian(min_chunk_size=1):
|
|||
|
||||
|
||||
def _make_doc_metadata(
|
||||
doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
|
||||
doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc"
|
||||
):
|
||||
meta = MagicMock()
|
||||
meta.id = doc_id
|
||||
meta.kind = kind
|
||||
meta.user = user
|
||||
meta.workspace = workspace
|
||||
meta.title = title
|
||||
meta.time = 1700000000
|
||||
meta.comments = ""
|
||||
|
|
@ -43,27 +47,27 @@ def _make_doc_metadata(
|
|||
|
||||
|
||||
def _make_begin_request(
|
||||
doc_id="doc-1", kind="application/pdf", user="alice",
|
||||
doc_id="doc-1", kind="application/pdf", workspace="alice",
|
||||
total_size=10_000_000, chunk_size=0
|
||||
):
|
||||
req = MagicMock()
|
||||
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user)
|
||||
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace)
|
||||
req.total_size = total_size
|
||||
req.chunk_size = chunk_size
|
||||
return req
|
||||
|
||||
|
||||
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
|
||||
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"):
|
||||
req = MagicMock()
|
||||
req.upload_id = upload_id
|
||||
req.chunk_index = chunk_index
|
||||
req.user = user
|
||||
req.workspace = workspace
|
||||
req.content = base64.b64encode(content)
|
||||
return req
|
||||
|
||||
|
||||
def _make_session(
|
||||
user="alice", total_chunks=5, chunk_size=2_000_000,
|
||||
workspace="alice", total_chunks=5, chunk_size=2_000_000,
|
||||
total_size=10_000_000, chunks_received=None, object_id="obj-1",
|
||||
s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
|
||||
):
|
||||
|
|
@ -72,11 +76,11 @@ def _make_session(
|
|||
if document_metadata is None:
|
||||
document_metadata = json.dumps({
|
||||
"id": document_id, "kind": "application/pdf",
|
||||
"user": user, "title": "Test", "time": 1700000000,
|
||||
"workspace": workspace, "title": "Test", "time": 1700000000,
|
||||
"comments": "", "tags": [],
|
||||
})
|
||||
return {
|
||||
"user": user,
|
||||
"workspace": workspace,
|
||||
"total_chunks": total_chunks,
|
||||
"chunk_size": chunk_size,
|
||||
"total_size": total_size,
|
||||
|
|
@ -255,10 +259,10 @@ class TestUploadChunk:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = _make_upload_chunk_request(user="bob")
|
||||
req = _make_upload_chunk_request(workspace="bob")
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
|
|
@ -349,7 +353,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.complete_upload(req)
|
||||
|
||||
|
|
@ -371,7 +375,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
await lib.complete_upload(req)
|
||||
|
||||
|
|
@ -390,7 +394,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="Missing chunks"):
|
||||
await lib.complete_upload(req)
|
||||
|
|
@ -402,7 +406,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-gone"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="not found"):
|
||||
await lib.complete_upload(req)
|
||||
|
|
@ -410,12 +414,12 @@ class TestCompleteUpload:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
req.workspace = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.complete_upload(req)
|
||||
|
|
@ -435,7 +439,7 @@ class TestAbortUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.abort_upload(req)
|
||||
|
||||
|
|
@ -452,7 +456,7 @@ class TestAbortUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-gone"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="not found"):
|
||||
await lib.abort_upload(req)
|
||||
|
|
@ -460,12 +464,12 @@ class TestAbortUpload:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
req.workspace = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.abort_upload(req)
|
||||
|
|
@ -488,7 +492,7 @@ class TestGetUploadStatus:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
|
|
@ -506,7 +510,7 @@ class TestGetUploadStatus:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-expired"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
|
|
@ -523,7 +527,7 @@ class TestGetUploadStatus:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
|
|
@ -535,12 +539,12 @@ class TestGetUploadStatus:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
req.workspace = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.get_upload_status(req)
|
||||
|
|
@ -560,7 +564,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
|
|
@ -583,7 +587,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
|
|
@ -604,7 +608,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
|
|
@ -626,7 +630,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x")
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 0 # Should use default 1MB
|
||||
|
||||
|
|
@ -645,7 +649,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=raw)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 1000
|
||||
|
||||
|
|
@ -662,7 +666,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_size = AsyncMock(return_value=5000)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 512
|
||||
|
||||
|
|
@ -694,7 +698,7 @@ class TestListUploads:
|
|||
]
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.list_uploads(req)
|
||||
|
||||
|
|
@ -709,7 +713,7 @@ class TestListUploads:
|
|||
lib.table_store.list_upload_sessions.return_value = []
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.list_uploads(req)
|
||||
|
||||
|
|
|
|||
590
tests/unit/test_provenance/test_dag_structure.py
Normal file
590
tests/unit/test_provenance/test_dag_structure.py
Normal file
|
|
@ -0,0 +1,590 @@
|
|||
"""
|
||||
DAG structure tests for provenance chains.
|
||||
|
||||
Verifies that the wasDerivedFrom chain has the expected shape for each
|
||||
service. These tests catch structural regressions when new entities are
|
||||
inserted into the chain (e.g. PatternDecision between session and first
|
||||
iteration).
|
||||
|
||||
Expected chains:
|
||||
|
||||
GraphRAG: question → grounding → exploration → focus → synthesis
|
||||
DocumentRAG: question → grounding → exploration → synthesis
|
||||
Agent React: session → pattern-decision → iteration → (observation → iteration)* → final
|
||||
Agent Plan: session → pattern-decision → plan → step-result(s) → synthesis
|
||||
Agent Super: session → pattern-decision → decomposition → (fan-out) → finding(s) → synthesis
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
Triple, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL,
|
||||
TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
|
||||
TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION,
|
||||
TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION,
|
||||
TG_OBSERVATION_TYPE,
|
||||
TG_PATTERN, TG_TASK_TYPE,
|
||||
)
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _collect_events(events):
|
||||
"""Build a dict of explain_id → {types, derived_from, triples}."""
|
||||
result = {}
|
||||
for ev in events:
|
||||
eid = ev["explain_id"]
|
||||
triples = ev["triples"]
|
||||
types = {
|
||||
t.o.iri for t in triples
|
||||
if t.s.iri == eid and t.p.iri == RDF_TYPE
|
||||
}
|
||||
parents = [
|
||||
t.o.iri for t in triples
|
||||
if t.s.iri == eid and t.p.iri == PROV_WAS_DERIVED_FROM
|
||||
]
|
||||
result[eid] = {
|
||||
"types": types,
|
||||
"derived_from": parents[0] if parents else None,
|
||||
"triples": triples,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def _find_by_type(dag, rdf_type):
|
||||
"""Find all event IDs that have the given rdf:type."""
|
||||
return [eid for eid, info in dag.items() if rdf_type in info["types"]]
|
||||
|
||||
|
||||
def _assert_chain(dag, chain_types):
|
||||
"""Assert that a linear wasDerivedFrom chain exists through the given types."""
|
||||
for i in range(1, len(chain_types)):
|
||||
parent_type = chain_types[i - 1]
|
||||
child_type = chain_types[i]
|
||||
parents = _find_by_type(dag, parent_type)
|
||||
children = _find_by_type(dag, child_type)
|
||||
assert parents, f"No entity with type {parent_type}"
|
||||
assert children, f"No entity with type {child_type}"
|
||||
# At least one child must derive from at least one parent
|
||||
linked = False
|
||||
for child_id in children:
|
||||
derived = dag[child_id]["derived_from"]
|
||||
if derived in parents:
|
||||
linked = True
|
||||
break
|
||||
assert linked, (
|
||||
f"No {child_type} derives from {parent_type}. "
|
||||
f"Children derive from: "
|
||||
f"{[dag[c]['derived_from'] for c in children]}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GraphRAG DAG structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagDagStructure:
|
||||
"""Verify: question → grounding → exploration → focus → synthesis"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clients(self):
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
graph_embeddings_client.query.return_value = [
|
||||
MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")),
|
||||
]
|
||||
triples_client.query_stream.return_value = [
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/e1"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="value"),
|
||||
)
|
||||
]
|
||||
triples_client.query.return_value = []
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
elif template_id == "kg-edge-scoring":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "score": 10} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text="Answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
rag = GraphRag(*mock_clients)
|
||||
events = []
|
||||
|
||||
async def explain_cb(triples, explain_id):
|
||||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb, edge_score_limit=0,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
assert len(dag) == 5, f"Expected 5 events, got {len(dag)}"
|
||||
|
||||
_assert_chain(dag, [
|
||||
TG_GRAPH_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_FOCUS,
|
||||
TG_SYNTHESIS,
|
||||
])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentRAG DAG structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentRagDagStructure:
|
||||
"""Verify: question → grounding → exploration → synthesis"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clients(self):
|
||||
from trustgraph.schema import ChunkMatch
|
||||
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
doc_embeddings_client = AsyncMock()
|
||||
fetch_chunk = AsyncMock(return_value="Chunk content.")
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
doc_embeddings_client.query.return_value = [
|
||||
ChunkMatch(chunk_id="doc/c1", score=0.9),
|
||||
]
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text", text="Answer.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
rag = DocumentRag(*mock_clients)
|
||||
events = []
|
||||
|
||||
async def explain_cb(triples, explain_id):
|
||||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
assert len(dag) == 4, f"Expected 4 events, got {len(dag)}"
|
||||
|
||||
_assert_chain(dag, [
|
||||
TG_DOC_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_SYNTHESIS,
|
||||
])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent DAG structure — tested via service.agent_request()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_processor(tools=None):
|
||||
processor = MagicMock()
|
||||
processor.max_iterations = 10
|
||||
processor.save_answer_content = AsyncMock()
|
||||
|
||||
def mock_session_uri(sid):
|
||||
return f"urn:trustgraph:agent:session:{sid}"
|
||||
processor.provenance_session_uri.side_effect = mock_session_uri
|
||||
|
||||
agent = MagicMock()
|
||||
agent.tools = tools or {}
|
||||
agent.additional_context = ""
|
||||
processor.agents = {"default": agent}
|
||||
processor.aggregator = MagicMock()
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def _make_flow():
|
||||
producers = {}
|
||||
|
||||
def factory(name):
|
||||
if name not in producers:
|
||||
producers[name] = AsyncMock()
|
||||
return producers[name]
|
||||
|
||||
flow = MagicMock(side_effect=factory)
|
||||
flow.workspace = "default"
|
||||
return flow
|
||||
|
||||
|
||||
def _collect_agent_events(respond_mock):
|
||||
events = []
|
||||
for call in respond_mock.call_args_list:
|
||||
resp = call[0][0]
|
||||
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
|
||||
events.append({
|
||||
"explain_id": resp.explain_id,
|
||||
"triples": resp.explain_triples,
|
||||
})
|
||||
return events
|
||||
|
||||
|
||||
class TestAgentReactDagStructure:
|
||||
"""
|
||||
Via service.agent_request(), full two-iteration react chain:
|
||||
session → pattern-decision → iteration(1) → observation(1) → final
|
||||
|
||||
Iteration 1: tool call → observation
|
||||
Iteration 2: final answer
|
||||
"""
|
||||
|
||||
def _make_service(self):
|
||||
from trustgraph.agent.orchestrator.service import Processor
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "lookup"
|
||||
mock_tool.description = "Look things up"
|
||||
mock_tool.arguments = []
|
||||
mock_tool.groups = []
|
||||
mock_tool.states = {}
|
||||
mock_tool_impl = AsyncMock(return_value="42")
|
||||
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
|
||||
|
||||
processor = _make_processor(tools={"lookup": mock_tool})
|
||||
|
||||
service = Processor.__new__(Processor)
|
||||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
service.plan_pattern = PlanThenExecutePattern(service)
|
||||
service.supervisor_pattern = SupervisorPattern(service)
|
||||
service.meta_router = None
|
||||
|
||||
return service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self):
|
||||
from trustgraph.agent.react.types import Action, Final
|
||||
|
||||
service = self._make_service()
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = _make_flow()
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Iteration 1: tool call → returns Action, triggers on_action + tool exec
|
||||
action = Action(
|
||||
thought="I need to look this up",
|
||||
name="lookup",
|
||||
arguments={"question": "6x7"},
|
||||
observation="",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react_iter1(on_action=None, **kwargs):
|
||||
if on_action:
|
||||
await on_action(action)
|
||||
action.observation = "42"
|
||||
return action
|
||||
|
||||
mock_am.react.side_effect = mock_react_iter1
|
||||
|
||||
request1 = AgentRequest(
|
||||
question="What is 6x7?",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
pattern="react",
|
||||
history=[],
|
||||
)
|
||||
|
||||
await service.agent_request(request1, respond, next_fn, flow)
|
||||
|
||||
# next_fn should have been called with updated history
|
||||
assert next_fn.called
|
||||
|
||||
# Iteration 2: final answer
|
||||
final = Final(thought="The answer is 42", final="42")
|
||||
next_request = next_fn.call_args[0][0]
|
||||
|
||||
with patch(
|
||||
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react_iter2(**kwargs):
|
||||
return final
|
||||
|
||||
mock_am.react.side_effect = mock_react_iter2
|
||||
|
||||
await service.agent_request(next_request, respond, next_fn, flow)
|
||||
|
||||
# Collect and verify DAG
|
||||
events = _collect_agent_events(respond)
|
||||
dag = _collect_events(events)
|
||||
|
||||
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
|
||||
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
|
||||
analysis_ids = _find_by_type(dag, TG_ANALYSIS)
|
||||
observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE)
|
||||
final_ids = _find_by_type(dag, TG_CONCLUSION)
|
||||
|
||||
assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}"
|
||||
assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}"
|
||||
assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}"
|
||||
assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}"
|
||||
assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}"
|
||||
|
||||
# Full chain:
|
||||
# session → pattern-decision
|
||||
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
|
||||
|
||||
# pattern-decision → iteration(1)
|
||||
assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0]
|
||||
|
||||
# iteration(1) → observation(1)
|
||||
assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0]
|
||||
|
||||
# observation(1) → final
|
||||
assert dag[final_ids[0]]["derived_from"] == observation_ids[0]
|
||||
|
||||
|
||||
class TestAgentPlanDagStructure:
|
||||
"""
|
||||
Via service.agent_request():
|
||||
session → pattern-decision → plan → step-result → synthesis
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self):
|
||||
from trustgraph.agent.orchestrator.service import Processor
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
# Mock tool
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "knowledge-query"
|
||||
mock_tool.description = "Query KB"
|
||||
mock_tool.arguments = []
|
||||
mock_tool.groups = []
|
||||
mock_tool.states = {}
|
||||
mock_tool_impl = AsyncMock(return_value="Found it")
|
||||
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
|
||||
|
||||
processor = _make_processor(tools={"knowledge-query": mock_tool})
|
||||
|
||||
service = Processor.__new__(Processor)
|
||||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
service.plan_pattern = PlanThenExecutePattern(service)
|
||||
service.supervisor_pattern = SupervisorPattern(service)
|
||||
service.meta_router = None
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = _make_flow()
|
||||
|
||||
# Mock prompt client
|
||||
mock_prompt_client = AsyncMock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_prompt(id, variables=None, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if id == "plan-create":
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}],
|
||||
)
|
||||
elif id == "plan-step-execute":
|
||||
return PromptResult(
|
||||
response_type="json",
|
||||
object={"tool": "knowledge-query", "arguments": {"question": "test"}},
|
||||
)
|
||||
elif id == "plan-synthesise":
|
||||
return PromptResult(response_type="text", text="Final answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
mock_prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Iteration 1: planning
|
||||
request1 = AgentRequest(
|
||||
question="Test?",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
pattern="plan-then-execute",
|
||||
history=[],
|
||||
)
|
||||
await service.agent_request(request1, respond, next_fn, flow)
|
||||
|
||||
# Iteration 2: execute step (next_fn was called with updated request)
|
||||
assert next_fn.called
|
||||
next_request = next_fn.call_args[0][0]
|
||||
|
||||
# Iteration 3: all steps done → synthesis
|
||||
# Simulate completed step in history
|
||||
next_request.history[-1].plan[0].status = "completed"
|
||||
next_request.history[-1].plan[0].result = "Found it"
|
||||
|
||||
await service.agent_request(next_request, respond, next_fn, flow)
|
||||
|
||||
events = _collect_agent_events(respond)
|
||||
dag = _collect_events(events)
|
||||
|
||||
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
|
||||
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
|
||||
plan_ids = _find_by_type(dag, TG_PLAN_TYPE)
|
||||
synthesis_ids = _find_by_type(dag, TG_SYNTHESIS)
|
||||
|
||||
assert len(session_ids) == 1
|
||||
assert len(pd_ids) == 1
|
||||
assert len(plan_ids) == 1
|
||||
assert len(synthesis_ids) == 1
|
||||
|
||||
# Chain: session → pattern-decision → plan → ... → synthesis
|
||||
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
|
||||
assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
|
||||
|
||||
|
||||
class TestAgentSupervisorDagStructure:
|
||||
"""
|
||||
Via service.agent_request():
|
||||
session → pattern-decision → decomposition → (fan-out)
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self):
|
||||
from trustgraph.agent.orchestrator.service import Processor
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
processor = _make_processor()
|
||||
|
||||
service = Processor.__new__(Processor)
|
||||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
service.plan_pattern = PlanThenExecutePattern(service)
|
||||
service.supervisor_pattern = SupervisorPattern(service)
|
||||
service.meta_router = None
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = _make_flow()
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=["Goal A", "Goal B"],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
request = AgentRequest(
|
||||
question="Research quantum computing",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=str(uuid.uuid4()),
|
||||
pattern="supervisor",
|
||||
history=[],
|
||||
)
|
||||
|
||||
await service.agent_request(request, respond, next_fn, flow)
|
||||
|
||||
events = _collect_agent_events(respond)
|
||||
dag = _collect_events(events)
|
||||
|
||||
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
|
||||
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
|
||||
decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)
|
||||
|
||||
assert len(session_ids) == 1
|
||||
assert len(pd_ids) == 1
|
||||
assert len(decomp_ids) == 1
|
||||
|
||||
# Chain: session → pattern-decision → decomposition
|
||||
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
|
||||
assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]
|
||||
|
||||
# Fan-out should have been called
|
||||
assert next_fn.call_count == 2 # One per goal
|
||||
|
|
@ -223,7 +223,7 @@ class TestDerivedEntityTriples:
|
|||
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
|
||||
|
||||
def test_chunk_entity_has_chunk_type(self):
|
||||
def test_chunk_entity_has_message_type(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"chunker", "1.0",
|
||||
|
|
|
|||
131
tests/unit/test_pubsub/test_kafka_backend.py
Normal file
131
tests/unit/test_pubsub/test_kafka_backend.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
"""
|
||||
Unit tests for Kafka backend — topic parsing and factory dispatch.
|
||||
Does not require a running Kafka instance.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import argparse
|
||||
|
||||
from trustgraph.base.kafka_backend import KafkaBackend
|
||||
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
|
||||
|
||||
|
||||
class TestKafkaParseTopic:
|
||||
|
||||
@pytest.fixture
|
||||
def backend(self):
|
||||
b = object.__new__(KafkaBackend)
|
||||
return b
|
||||
|
||||
def test_flow_is_durable(self, backend):
|
||||
name, cls, durable = backend._parse_topic('flow:tg:text-completion-request')
|
||||
assert durable is True
|
||||
assert cls == 'flow'
|
||||
assert name == 'tg.flow.text-completion-request'
|
||||
|
||||
def test_notify_is_not_durable(self, backend):
|
||||
name, cls, durable = backend._parse_topic('notify:tg:config')
|
||||
assert durable is False
|
||||
assert cls == 'notify'
|
||||
assert name == 'tg.notify.config'
|
||||
|
||||
def test_request_is_not_durable(self, backend):
|
||||
name, cls, durable = backend._parse_topic('request:tg:config')
|
||||
assert durable is False
|
||||
assert cls == 'request'
|
||||
assert name == 'tg.request.config'
|
||||
|
||||
def test_response_is_not_durable(self, backend):
|
||||
name, cls, durable = backend._parse_topic('response:tg:librarian')
|
||||
assert durable is False
|
||||
assert cls == 'response'
|
||||
assert name == 'tg.response.librarian'
|
||||
|
||||
def test_custom_topicspace(self, backend):
|
||||
name, cls, durable = backend._parse_topic('flow:prod:my-queue')
|
||||
assert name == 'prod.flow.my-queue'
|
||||
assert durable is True
|
||||
|
||||
def test_no_colon_defaults_to_flow(self, backend):
|
||||
name, cls, durable = backend._parse_topic('simple-queue')
|
||||
assert name == 'tg.flow.simple-queue'
|
||||
assert cls == 'flow'
|
||||
assert durable is True
|
||||
|
||||
def test_invalid_class_raises(self, backend):
|
||||
with pytest.raises(ValueError, match="Invalid topic class"):
|
||||
backend._parse_topic('unknown:tg:topic')
|
||||
|
||||
def test_topic_with_flow_suffix(self, backend):
|
||||
"""Topic names with flow suffix (e.g. :default) have colons replaced with dots."""
|
||||
name, cls, durable = backend._parse_topic('request:tg:prompt:default')
|
||||
assert name == 'tg.request.prompt.default'
|
||||
|
||||
|
||||
class TestKafkaRetention:
|
||||
|
||||
@pytest.fixture
|
||||
def backend(self):
|
||||
b = object.__new__(KafkaBackend)
|
||||
return b
|
||||
|
||||
def test_flow_gets_long_retention(self, backend):
|
||||
assert backend._retention_ms('flow') == 7 * 24 * 60 * 60 * 1000
|
||||
|
||||
def test_request_gets_short_retention(self, backend):
|
||||
assert backend._retention_ms('request') == 300 * 1000
|
||||
|
||||
def test_response_gets_short_retention(self, backend):
|
||||
assert backend._retention_ms('response') == 300 * 1000
|
||||
|
||||
def test_notify_gets_short_retention(self, backend):
|
||||
assert backend._retention_ms('notify') == 300 * 1000
|
||||
|
||||
|
||||
class TestGetPubsubKafka:
|
||||
|
||||
def test_factory_creates_kafka_backend(self):
|
||||
backend = get_pubsub(pubsub_backend='kafka')
|
||||
assert isinstance(backend, KafkaBackend)
|
||||
|
||||
def test_factory_passes_config(self):
|
||||
backend = get_pubsub(
|
||||
pubsub_backend='kafka',
|
||||
kafka_bootstrap_servers='myhost:9093',
|
||||
kafka_security_protocol='SASL_SSL',
|
||||
kafka_sasl_mechanism='PLAIN',
|
||||
kafka_sasl_username='user',
|
||||
kafka_sasl_password='pass',
|
||||
)
|
||||
assert isinstance(backend, KafkaBackend)
|
||||
assert backend._bootstrap_servers == 'myhost:9093'
|
||||
assert backend._admin_config['security.protocol'] == 'SASL_SSL'
|
||||
assert backend._admin_config['sasl.mechanism'] == 'PLAIN'
|
||||
assert backend._admin_config['sasl.username'] == 'user'
|
||||
assert backend._admin_config['sasl.password'] == 'pass'
|
||||
|
||||
|
||||
class TestAddPubsubArgsKafka:
|
||||
|
||||
def test_kafka_args_present(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([
|
||||
'--pubsub-backend', 'kafka',
|
||||
'--kafka-bootstrap-servers', 'myhost:9093',
|
||||
])
|
||||
assert args.pubsub_backend == 'kafka'
|
||||
assert args.kafka_bootstrap_servers == 'myhost:9093'
|
||||
|
||||
def test_kafka_defaults_container(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([])
|
||||
assert args.kafka_bootstrap_servers == 'kafka:9092'
|
||||
assert args.kafka_security_protocol == 'PLAINTEXT'
|
||||
|
||||
def test_kafka_standalone_defaults_to_localhost(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser, standalone=True)
|
||||
args = parser.parse_args([])
|
||||
assert args.kafka_bootstrap_servers == 'localhost:9092'
|
||||
|
|
@ -1,18 +1,16 @@
|
|||
"""
|
||||
Unit tests for RabbitMQ backend — queue name mapping and factory dispatch.
|
||||
Unit tests for RabbitMQ backend — topic parsing and factory dispatch.
|
||||
Does not require a running RabbitMQ instance.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import argparse
|
||||
|
||||
pika = pytest.importorskip("pika", reason="pika not installed")
|
||||
|
||||
from trustgraph.base.rabbitmq_backend import RabbitMQBackend
|
||||
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
|
||||
|
||||
|
||||
class TestRabbitMQMapQueueName:
|
||||
class TestRabbitMQParseTopic:
|
||||
|
||||
@pytest.fixture
|
||||
def backend(self):
|
||||
|
|
@ -20,43 +18,48 @@ class TestRabbitMQMapQueueName:
|
|||
return b
|
||||
|
||||
def test_flow_is_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('flow:tg:text-completion-request')
|
||||
exchange, cls, durable = backend._parse_topic('flow:tg:text-completion-request')
|
||||
assert durable is True
|
||||
assert name == 'tg.flow.text-completion-request'
|
||||
assert cls == 'flow'
|
||||
assert exchange == 'tg.flow.text-completion-request'
|
||||
|
||||
def test_notify_is_not_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('notify:tg:config')
|
||||
exchange, cls, durable = backend._parse_topic('notify:tg:config')
|
||||
assert durable is False
|
||||
assert name == 'tg.notify.config'
|
||||
assert cls == 'notify'
|
||||
assert exchange == 'tg.notify.config'
|
||||
|
||||
def test_request_is_not_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('request:tg:config')
|
||||
exchange, cls, durable = backend._parse_topic('request:tg:config')
|
||||
assert durable is False
|
||||
assert name == 'tg.request.config'
|
||||
assert cls == 'request'
|
||||
assert exchange == 'tg.request.config'
|
||||
|
||||
def test_response_is_not_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('response:tg:librarian')
|
||||
exchange, cls, durable = backend._parse_topic('response:tg:librarian')
|
||||
assert durable is False
|
||||
assert name == 'tg.response.librarian'
|
||||
assert cls == 'response'
|
||||
assert exchange == 'tg.response.librarian'
|
||||
|
||||
def test_custom_topicspace(self, backend):
|
||||
name, durable = backend.map_queue_name('flow:prod:my-queue')
|
||||
assert name == 'prod.flow.my-queue'
|
||||
exchange, cls, durable = backend._parse_topic('flow:prod:my-queue')
|
||||
assert exchange == 'prod.flow.my-queue'
|
||||
assert durable is True
|
||||
|
||||
def test_no_colon_defaults_to_flow(self, backend):
|
||||
name, durable = backend.map_queue_name('simple-queue')
|
||||
assert name == 'tg.simple-queue'
|
||||
assert durable is False
|
||||
exchange, cls, durable = backend._parse_topic('simple-queue')
|
||||
assert exchange == 'tg.flow.simple-queue'
|
||||
assert cls == 'flow'
|
||||
assert durable is True
|
||||
|
||||
def test_invalid_class_raises(self, backend):
|
||||
with pytest.raises(ValueError, match="Invalid queue class"):
|
||||
backend.map_queue_name('unknown:tg:topic')
|
||||
with pytest.raises(ValueError, match="Invalid topic class"):
|
||||
backend._parse_topic('unknown:tg:topic')
|
||||
|
||||
def test_flow_with_flow_suffix(self, backend):
|
||||
"""Queue names with flow suffix (e.g. :default) are preserved."""
|
||||
name, durable = backend.map_queue_name('request:tg:prompt:default')
|
||||
assert name == 'tg.request.prompt:default'
|
||||
def test_topic_with_flow_suffix(self, backend):
|
||||
"""Topic names with flow suffix (e.g. :default) are preserved."""
|
||||
exchange, cls, durable = backend._parse_topic('request:tg:prompt:default')
|
||||
assert exchange == 'tg.request.prompt:default'
|
||||
|
||||
|
||||
class TestGetPubsubRabbitMQ:
|
||||
|
|
|
|||
|
|
@ -304,14 +304,14 @@ class TestStreamingTypes:
|
|||
|
||||
assert chunk.content == "thinking..."
|
||||
assert chunk.end_of_message is False
|
||||
assert chunk.chunk_type == "thought"
|
||||
assert chunk.message_type == "thought"
|
||||
|
||||
def test_agent_observation_creation(self):
|
||||
"""Test creating AgentObservation chunk"""
|
||||
chunk = AgentObservation(content="observing...", end_of_message=False)
|
||||
|
||||
assert chunk.content == "observing..."
|
||||
assert chunk.chunk_type == "observation"
|
||||
assert chunk.message_type == "observation"
|
||||
|
||||
def test_agent_answer_creation(self):
|
||||
"""Test creating AgentAnswer chunk"""
|
||||
|
|
@ -324,7 +324,7 @@ class TestStreamingTypes:
|
|||
assert chunk.content == "answer"
|
||||
assert chunk.end_of_message is True
|
||||
assert chunk.end_of_dialog is True
|
||||
assert chunk.chunk_type == "final-answer"
|
||||
assert chunk.message_type == "final-answer"
|
||||
|
||||
def test_rag_chunk_creation(self):
|
||||
"""Test creating RAGChunk"""
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
def mock_query_request(self):
|
||||
"""Create a mock query request for testing"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
|
|
@ -69,7 +68,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_single_vector(self, processor):
|
||||
"""Test querying document embeddings with a single vector"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -83,7 +81,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with correct parameters including user/collection
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -101,7 +99,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_longer_vector(self, processor):
|
||||
"""Test querying document embeddings with a longer vector"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=3
|
||||
|
|
@ -115,7 +112,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -133,7 +130,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_with_limit(self, processor):
|
||||
"""Test querying document embeddings respects limit parameter"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
|
|
@ -148,7 +144,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with the specified limit
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -162,13 +158,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_empty_vectors(self, processor):
|
||||
"""Test querying document embeddings with empty vectors list"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -180,7 +175,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_empty_search_results(self, processor):
|
||||
"""Test querying document embeddings with empty search results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -189,7 +183,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
# Mock empty search results
|
||||
processor.vecstore.search.return_value = []
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -203,7 +197,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_unicode_documents(self, processor):
|
||||
"""Test querying document embeddings with Unicode document content"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -217,7 +210,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify Unicode content is preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
|
|
@ -230,7 +223,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_large_documents(self, processor):
|
||||
"""Test querying document embeddings with large document content"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -244,7 +236,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify large content is preserved in ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
|
|
@ -256,7 +248,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_special_characters(self, processor):
|
||||
"""Test querying document embeddings with special characters in documents"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -270,7 +261,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify special characters are preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
|
|
@ -283,13 +274,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||
"""Test querying document embeddings with zero limit"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called (optimization for zero limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -301,13 +291,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_negative_limit(self, processor):
|
||||
"""Test querying document embeddings with negative limit"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=-1
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called (optimization for negative limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -319,7 +308,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_exception_handling(self, processor):
|
||||
"""Test exception handling during query processing"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -330,13 +318,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Milvus connection failed"):
|
||||
await processor.query_document_embeddings(query)
|
||||
await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying document embeddings with different vector dimensions"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
|
||||
limit=5
|
||||
|
|
@ -349,7 +336,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with the vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
|
@ -364,7 +351,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_multiple_results(self, processor):
|
||||
"""Test querying document embeddings with multiple results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
|
|
@ -378,7 +364,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index was accessed correctly (with dimension suffix)
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
chunks = await processor.query_document_embeddings(mock_query_message)
|
||||
chunks = await processor.query_document_embeddings('default', mock_query_message)
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
|
|
@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify limit is passed to query
|
||||
mock_index.query.assert_called_once()
|
||||
|
|
@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify different indexes used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 2
|
||||
|
|
@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no queries were made and empty result returned
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_results.matches = []
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify empty results
|
||||
assert chunks == []
|
||||
|
|
@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify Unicode content is properly handled
|
||||
assert len(chunks) == 2
|
||||
|
|
@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify large content is properly handled
|
||||
assert len(chunks) == 1
|
||||
|
|
@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify all content types are properly handled
|
||||
assert len(chunks) == 5
|
||||
|
|
@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index.query.side_effect = Exception("Query failed")
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_document_embeddings(message)
|
||||
await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_index_access_failure(self, processor):
|
||||
|
|
@ -427,7 +427,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
processor.pinecone.Index.side_effect = Exception("Index access failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index access failed"):
|
||||
await processor.query_document_embeddings(message)
|
||||
await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_vector_accumulation(self, processor):
|
||||
|
|
@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify all queries were made
|
||||
assert mock_index.query.call_count == 3
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters (with dimension suffix)
|
||||
|
|
@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once
|
||||
|
|
@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('limit_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with exact limit (no multiplication)
|
||||
|
|
@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
|
@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('dim_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once with correct collection
|
||||
|
|
@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'utf8_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('utf8_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
|
|
@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
await processor.query_document_embeddings('error_user', mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('zero_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
|
|
@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'large_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('large_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should query with full limit
|
||||
|
|
@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# Act & Assert
|
||||
# This should raise a KeyError when trying to access payload['chunk_id']
|
||||
with pytest.raises(KeyError):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
await processor.query_document_embeddings('payload_user', mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
def mock_query_request(self):
|
||||
"""Create a mock query request for testing"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
|
|
@ -117,7 +116,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
"""Test querying graph embeddings with a single vector"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -131,7 +129,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with correct parameters including user/collection
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -154,7 +152,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_multiple_results(self, processor):
|
||||
"""Test querying graph embeddings returns multiple results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
|
|
@ -168,7 +165,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -186,7 +183,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_with_limit(self, processor):
|
||||
"""Test querying graph embeddings respects limit parameter"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
|
|
@ -201,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with 2*limit for better deduplication
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -215,7 +211,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_preserves_order(self, processor):
|
||||
"""Test that query results preserve order from the vector store"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
|
|
@ -229,7 +224,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify results are in the same order as returned by the store
|
||||
assert len(result) == 3
|
||||
|
|
@ -241,7 +236,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_results_limited(self, processor):
|
||||
"""Test that results are properly limited when store returns more than requested"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=2
|
||||
|
|
@ -255,7 +249,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -269,13 +263,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_empty_vectors(self, processor):
|
||||
"""Test querying graph embeddings with empty vectors list"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -287,7 +280,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_empty_search_results(self, processor):
|
||||
"""Test querying graph embeddings with empty search results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -296,7 +288,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
# Mock empty search results
|
||||
processor.vecstore.search.return_value = []
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -310,7 +302,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
|
||||
"""Test querying graph embeddings with mixed URI and literal results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -325,7 +316,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify all results are properly typed
|
||||
assert len(result) == 4
|
||||
|
|
@ -348,7 +339,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_exception_handling(self, processor):
|
||||
"""Test exception handling during query processing"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -359,7 +349,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Milvus connection failed"):
|
||||
await processor.query_graph_embeddings(query)
|
||||
await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
@ -430,13 +420,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_zero_limit(self, processor):
|
||||
"""Test querying graph embeddings with zero limit"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called (optimization for zero limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -448,7 +437,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_longer_vector(self, processor):
|
||||
"""Test querying graph embeddings with a longer vector"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
limit=5
|
||||
|
|
@ -461,7 +449,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index was accessed correctly (with dimension suffix)
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(mock_query_message)
|
||||
entities = await processor.query_graph_embeddings('default', mock_query_message)
|
||||
|
||||
# Verify query was made once
|
||||
assert mock_index.query.call_count == 1
|
||||
|
|
@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify limit is respected
|
||||
assert len(entities) == 2
|
||||
|
|
@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify correct index used for 2D vector
|
||||
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
|
||||
|
|
@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no queries were made and empty result returned
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_results.matches = []
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify empty results
|
||||
assert entities == []
|
||||
|
|
@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Should get exactly 3 unique entities (respecting limit)
|
||||
assert len(entities) == 3
|
||||
|
|
@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Should only return 2 entities (respecting limit)
|
||||
mock_index.query.assert_called_once()
|
||||
|
|
@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index.query.side_effect = Exception("Query failed")
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_graph_embeddings(message)
|
||||
await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters (with dimension suffix)
|
||||
|
|
@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once
|
||||
|
|
@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('limit_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with limit * 2
|
||||
|
|
@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
|
@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('dim_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once
|
||||
|
|
@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'uri_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('uri_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
|
|
@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_graph_embeddings(mock_message)
|
||||
await processor.query_graph_embeddings('error_user', mock_message)
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('zero_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor
|
|||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMemgraphQueryUserCollectionIsolation:
|
||||
class TestMemgraphQueryWorkspaceCollectionIsolation:
|
||||
"""Test cases for Memgraph query service with user/collection isolation"""
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_spo_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SPO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -22,7 +22,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -32,13 +31,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SPO query for literal includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -48,14 +47,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
value="test_object",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_sp_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SP query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -63,7 +62,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -73,13 +71,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SP query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -88,14 +86,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
expected_literal_query,
|
||||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_so_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_so_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -103,7 +101,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -113,13 +110,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SO query for nodes includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -128,14 +125,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
src="http://example.com/s",
|
||||
uri="http://example.com/o",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_s_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test S-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -143,7 +140,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -153,13 +149,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify S query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -167,14 +163,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
src="http://example.com/s",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_po_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_po_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test PO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -182,7 +178,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -192,13 +187,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify PO query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -207,14 +202,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
value="literal",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_p_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test P-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -222,7 +217,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -232,13 +226,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify P query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -246,14 +240,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_o_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test O-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -261,7 +255,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -271,13 +264,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify O query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -285,14 +278,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
value="test_value",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test wildcard query (all None) includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -300,7 +293,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -310,36 +302,36 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify wildcard query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
# Verify wildcard query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_node_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
|
@ -363,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('default', query)
|
||||
|
||||
# Verify defaults were used
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
|
|
@ -383,7 +375,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -410,7 +401,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
([mock_record2], MagicMock(), MagicMock()) # Node query
|
||||
]
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify results are proper Triple objects
|
||||
assert len(result) == 2
|
||||
|
|
@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor
|
|||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestNeo4jQueryUserCollectionIsolation:
|
||||
class TestNeo4jQueryWorkspaceCollectionIsolation:
|
||||
"""Test cases for Neo4j query service with user/collection isolation"""
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_spo_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SPO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -22,7 +22,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -32,13 +31,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SPO query for literal includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -48,14 +47,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
value="test_object",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_sp_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SP query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -63,7 +62,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -73,13 +71,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SP query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -88,16 +86,16 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_literal_query,
|
||||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
# Verify SP query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -106,14 +104,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_node_query,
|
||||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_so_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_so_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -121,7 +119,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -131,13 +128,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SO query for nodes includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -146,14 +143,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
src="http://example.com/s",
|
||||
uri="http://example.com/o",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_s_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test S-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -161,7 +158,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -171,13 +167,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify S query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -185,14 +181,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
src="http://example.com/s",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_po_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_po_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test PO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -200,7 +196,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -210,13 +205,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify PO query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -225,14 +220,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
value="literal",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_p_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test P-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -240,7 +235,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -250,13 +244,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify P query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -264,14 +258,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_o_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test O-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -279,7 +273,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -289,13 +282,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify O query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -303,14 +296,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
value="test_value",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test wildcard query (all None) includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -318,7 +311,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -328,36 +320,36 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify wildcard query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
# Verify wildcard query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_node_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
|
@ -381,7 +373,7 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('default', query)
|
||||
|
||||
# Verify defaults were used
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
|
|
@ -401,7 +393,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -428,7 +419,7 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
([mock_record2], MagicMock(), MagicMock()) # Node query
|
||||
]
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify results are proper Triple objects
|
||||
assert len(result) == 2
|
||||
|
|
@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic:
|
|||
"""Test parsing of schema configuration"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.schema_builders = {}
|
||||
processor.graphql_schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.schema_builder = MagicMock()
|
||||
processor.schema_builder.clear = MagicMock()
|
||||
processor.schema_builder.add_schema = MagicMock()
|
||||
processor.schema_builder.build = MagicMock(return_value=MagicMock())
|
||||
processor.query_cassandra = MagicMock()
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test config
|
||||
|
|
@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic:
|
|||
}
|
||||
|
||||
# Process config
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
await processor.on_schema_config("default", schema_config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer" in processor.schemas
|
||||
schema = processor.schemas["customer"]
|
||||
assert "customer" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["customer"]
|
||||
assert schema.name == "customer"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
|
@ -147,39 +146,40 @@ class TestRowsGraphQLQueryLogic:
|
|||
status_field = next(f for f in schema.fields if f.name == "status")
|
||||
assert status_field.enum_values == ["active", "inactive"]
|
||||
|
||||
# Verify schema builder was called
|
||||
processor.schema_builder.add_schema.assert_called_once()
|
||||
processor.schema_builder.build.assert_called_once()
|
||||
# Verify per-workspace schema builder was created and graphql schema built
|
||||
assert "default" in processor.schema_builders
|
||||
assert "default" in processor.graphql_schemas
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphql_context_handling(self):
|
||||
"""Test GraphQL execution context setup"""
|
||||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
graphql_schema = AsyncMock()
|
||||
processor.graphql_schemas = {"default": graphql_schema}
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Mock schema execution
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
|
||||
mock_result.errors = None
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
graphql_schema.execute.return_value = mock_result
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
workspace="default",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify schema.execute was called with correct context
|
||||
processor.graphql_schema.execute.assert_called_once()
|
||||
call_args = processor.graphql_schema.execute.call_args
|
||||
graphql_schema.execute.assert_called_once()
|
||||
call_args = graphql_schema.execute.call_args
|
||||
|
||||
# Verify context was passed
|
||||
context = call_args[1]['context_value']
|
||||
assert context["processor"] == processor
|
||||
assert context["user"] == "test_user"
|
||||
assert context["workspace"] == "default"
|
||||
assert context["collection"] == "test_collection"
|
||||
|
||||
# Verify result structure
|
||||
|
|
@ -190,7 +190,8 @@ class TestRowsGraphQLQueryLogic:
|
|||
async def test_error_handling_graphql_errors(self):
|
||||
"""Test GraphQL error handling and conversion"""
|
||||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
graphql_schema = AsyncMock()
|
||||
processor.graphql_schemas = {"default": graphql_schema}
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Create a simple object to simulate GraphQL error
|
||||
|
|
@ -212,13 +213,13 @@ class TestRowsGraphQLQueryLogic:
|
|||
mock_result = MagicMock()
|
||||
mock_result.data = None
|
||||
mock_result.errors = [mock_error]
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
graphql_schema.execute.return_value = mock_result
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
workspace="default",
|
||||
query='{ customers { invalid_field } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -248,7 +249,6 @@ class TestRowsGraphQLQueryLogic:
|
|||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
|
|
@ -259,6 +259,7 @@ class TestRowsGraphQLQueryLogic:
|
|||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "default"
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
|
@ -267,10 +268,10 @@ class TestRowsGraphQLQueryLogic:
|
|||
|
||||
# Verify query was executed
|
||||
processor.execute_graphql_query.assert_called_once_with(
|
||||
workspace="default",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -297,7 +298,6 @@ class TestRowsGraphQLQueryLogic:
|
|||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ invalid_query }',
|
||||
variables={},
|
||||
|
|
@ -330,7 +330,8 @@ class TestUnifiedTableQueries:
|
|||
"""Test queries against the unified rows table"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_index_match(self):
|
||||
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
|
||||
async def test_query_with_index_match(self, mock_async_execute):
|
||||
"""Test query execution with matching index"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
|
|
@ -340,10 +341,10 @@ class TestUnifiedTableQueries:
|
|||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to return test data
|
||||
# Mock async_execute to return test data
|
||||
mock_row = MagicMock()
|
||||
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
|
||||
processor.session.execute.return_value = [mock_row]
|
||||
mock_async_execute.return_value = [mock_row]
|
||||
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
|
|
@ -356,7 +357,7 @@ class TestUnifiedTableQueries:
|
|||
|
||||
# Query with filter on indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
|
|
@ -366,14 +367,14 @@ class TestUnifiedTableQueries:
|
|||
|
||||
# Verify Cassandra was connected and queried
|
||||
processor.connect_cassandra.assert_called_once()
|
||||
processor.session.execute.assert_called_once()
|
||||
mock_async_execute.assert_called_once()
|
||||
|
||||
# Verify query structure - should query unified rows table
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
params = call_args[0][1]
|
||||
call_args = mock_async_execute.call_args
|
||||
query = call_args[0][1]
|
||||
params = call_args[0][2]
|
||||
|
||||
assert "SELECT data, source FROM test_user.rows" in query
|
||||
assert "SELECT data, source FROM test_workspace.rows" in query
|
||||
assert "collection = %s" in query
|
||||
assert "schema_name = %s" in query
|
||||
assert "index_name = %s" in query
|
||||
|
|
@ -390,7 +391,8 @@ class TestUnifiedTableQueries:
|
|||
assert results[0]["category"] == "electronics"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_without_index_match(self):
|
||||
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
|
||||
async def test_query_without_index_match(self, mock_async_execute):
|
||||
"""Test query execution without matching index (scan mode)"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
|
|
@ -401,12 +403,12 @@ class TestUnifiedTableQueries:
|
|||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to return test data
|
||||
# Mock async_execute to return test data
|
||||
mock_row1 = MagicMock()
|
||||
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
|
||||
mock_row2 = MagicMock()
|
||||
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
|
||||
processor.session.execute.return_value = [mock_row1, mock_row2]
|
||||
mock_async_execute.return_value = [mock_row1, mock_row2]
|
||||
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
|
|
@ -419,7 +421,7 @@ class TestUnifiedTableQueries:
|
|||
|
||||
# Query with filter on non-indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
|
|
@ -428,8 +430,8 @@ class TestUnifiedTableQueries:
|
|||
)
|
||||
|
||||
# Query should use ALLOW FILTERING for scan
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
call_args = mock_async_execute.call_args
|
||||
query = call_args[0][1]
|
||||
|
||||
assert "ALLOW FILTERING" in query
|
||||
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
# Create query request with all SPO values
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -103,7 +102,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify KnowledgeGraph was created with correct parameters
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -170,7 +169,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -178,7 +176,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=50
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
|
|
@ -207,7 +205,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -215,7 +212,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=25
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
|
|
@ -244,7 +241,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -252,7 +248,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=10
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
|
|
@ -281,7 +277,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -289,7 +284,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=75
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
|
|
@ -319,7 +314,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -327,7 +321,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=1000
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
|
|
@ -425,7 +419,6 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -433,7 +426,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify KnowledgeGraph was created with authentication
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -463,7 +456,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -472,11 +464,11 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
# First query should create TrustGraph
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
assert mock_kg_class.call_count == 1
|
||||
|
||||
# Second query with same table should reuse TrustGraph
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
assert mock_kg_class.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -504,7 +496,6 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
# First query
|
||||
query1 = TriplesQueryRequest(
|
||||
user='user1',
|
||||
collection='collection1',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -512,12 +503,11 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query1)
|
||||
await processor.query_triples('user1', query1)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
user='user2',
|
||||
collection='collection2',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -525,7 +515,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query2)
|
||||
await processor.query_triples('user2', query2)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
|
|
@ -544,7 +534,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -553,7 +542,7 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -582,7 +571,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -590,7 +578,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].o.value == 'object1'
|
||||
|
|
@ -621,7 +609,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# PO query pattern (predicate + object, find subjects)
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -629,7 +616,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=50
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
|
|
@ -662,7 +649,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# OS query pattern (object + subject, find predicates)
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -670,7 +656,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=25
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.get_os.assert_called_once_with(
|
||||
|
|
@ -721,7 +707,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_tg_instance.reset_mock()
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value=s) if s else None,
|
||||
p=Term(type=LITERAL, value=p) if p else None,
|
||||
|
|
@ -729,7 +714,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=10
|
||||
)
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify the correct method was called
|
||||
method = getattr(mock_tg_instance, expected_method)
|
||||
|
|
@ -780,7 +765,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# This is the query pattern that was slow with ALLOW FILTERING
|
||||
query = TriplesQueryRequest(
|
||||
user='large_dataset_user',
|
||||
collection='massive_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
|
||||
|
|
@ -788,7 +772,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=1000
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('large_dataset_user', query)
|
||||
|
||||
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue