Merge remote-tracking branch 'origin/master' into ts-port

This commit is contained in:
elpresidank 2026-04-26 20:07:57 -05:00
commit f8252ecd54
1038 changed files with 253274 additions and 8466 deletions

View file

@ -8,6 +8,7 @@ import pytest
# import asyncio
# import tracemalloc
# import warnings
import logging
from unittest.mock import MagicMock
# Uncomment the lines below to enable asyncio debug mode and tracemalloc
@ -33,19 +34,14 @@ def mock_loki_handler(session_mocker=None):
# Create a mock LokiHandler that does nothing
original_loki_handler = logging_loki.LokiHandler
class MockLokiHandler:
class MockLokiHandler(logging.Handler):
"""Mock LokiHandler that doesn't make network calls."""
def __init__(self, *args, **kwargs):
pass
super().__init__()
def emit(self, record):
pass
def flush(self):
pass
def close(self):
pass
return
# Replace the real LokiHandler with our mock
logging_loki.LokiHandler = MockLokiHandler

View file

@ -72,7 +72,6 @@ def sample_message_data():
},
"DocumentRagQuery": {
"query": "What is artificial intelligence?",
"user": "test_user",
"collection": "test_collection",
"doc_limit": 10
},
@ -87,7 +86,7 @@ def sample_message_data():
"history": []
},
"AgentResponse": {
"chunk_type": "answer",
"message_type": "answer",
"content": "Machine learning is a subset of AI.",
"end_of_message": True,
"end_of_dialog": True,
@ -95,7 +94,6 @@ def sample_message_data():
},
"Metadata": {
"id": "test-doc-123",
"user": "test_user",
"collection": "test_collection"
},
"Term": {
@ -130,9 +128,8 @@ def invalid_message_data():
{}, # Missing required fields
],
"DocumentRagQuery": [
{"query": None, "user": "test", "collection": "test", "doc_limit": 10}, # Invalid query
{"query": "test", "user": None, "collection": "test", "doc_limit": 10}, # Invalid user
{"query": "test", "user": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
{"query": None, "collection": "test", "doc_limit": 10}, # Invalid query
{"query": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
{"query": "test"}, # Missing required fields
],
"Term": [

View file

@ -18,24 +18,18 @@ class TestDocumentEmbeddingsRequestContract:
def test_request_schema_fields(self):
"""Test that DocumentEmbeddingsRequest has expected fields"""
# Create a request
request = DocumentEmbeddingsRequest(
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
)
# Verify all expected fields exist
assert hasattr(request, 'vector')
assert hasattr(request, 'limit')
assert hasattr(request, 'user')
assert hasattr(request, 'collection')
# Verify field values
assert request.vector == [0.1, 0.2, 0.3]
assert request.limit == 10
assert request.user == "test_user"
assert request.collection == "test_collection"
def test_request_translator_decode(self):
@ -45,7 +39,6 @@ class TestDocumentEmbeddingsRequestContract:
data = {
"vector": [0.1, 0.2, 0.3, 0.4],
"limit": 5,
"user": "custom_user",
"collection": "custom_collection"
}
@ -54,7 +47,6 @@ class TestDocumentEmbeddingsRequestContract:
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vector == [0.1, 0.2, 0.3, 0.4]
assert result.limit == 5
assert result.user == "custom_user"
assert result.collection == "custom_collection"
def test_request_translator_decode_with_defaults(self):
@ -63,7 +55,7 @@ class TestDocumentEmbeddingsRequestContract:
data = {
"vector": [0.1, 0.2]
# No limit, user, or collection provided
# No limit or collection provided
}
result = translator.decode(data)
@ -71,7 +63,6 @@ class TestDocumentEmbeddingsRequestContract:
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vector == [0.1, 0.2]
assert result.limit == 10 # Default
assert result.user == "trustgraph" # Default
assert result.collection == "default" # Default
def test_request_translator_encode(self):
@ -81,7 +72,6 @@ class TestDocumentEmbeddingsRequestContract:
request = DocumentEmbeddingsRequest(
vector=[0.5, 0.6],
limit=20,
user="test_user",
collection="test_collection"
)
@ -90,7 +80,6 @@ class TestDocumentEmbeddingsRequestContract:
assert isinstance(result, dict)
assert result["vector"] == [0.5, 0.6]
assert result["limit"] == 20
assert result["user"] == "test_user"
assert result["collection"] == "test_collection"
@ -219,7 +208,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
request_data = {
"vector": [0.1, 0.2, 0.3],
"limit": 5,
"user": "test_user",
"collection": "test_collection"
}

View file

@ -132,7 +132,6 @@ class TestDocumentRagMessageContracts:
# Test required fields
query = DocumentRagQuery(**query_data)
assert hasattr(query, 'query')
assert hasattr(query, 'user')
assert hasattr(query, 'collection')
assert hasattr(query, 'doc_limit')
@ -154,12 +153,10 @@ class TestDocumentRagMessageContracts:
# Test valid query
valid_query = DocumentRagQuery(
query="What is AI?",
user="test_user",
collection="test_collection",
doc_limit=5
)
assert valid_query.query == "What is AI?"
assert valid_query.user == "test_user"
assert valid_query.collection == "test_collection"
assert valid_query.doc_limit == 5
@ -212,7 +209,7 @@ class TestAgentMessageContracts:
# Test required fields
response = AgentResponse(**response_data)
assert hasattr(response, 'chunk_type')
assert hasattr(response, 'message_type')
assert hasattr(response, 'content')
assert hasattr(response, 'end_of_message')
assert hasattr(response, 'end_of_dialog')
@ -400,7 +397,6 @@ class TestMetadataMessageContracts:
metadata = Metadata(**metadata_data)
assert metadata.id == "test-doc-123"
assert metadata.user == "test_user"
assert metadata.collection == "test_collection"
def test_error_schema_contract(self):
@ -491,7 +487,7 @@ class TestSchemaEvolutionContracts:
required_fields = {
"TextCompletionRequest": ["system", "prompt"],
"TextCompletionResponse": ["error", "response", "model"],
"DocumentRagQuery": ["query", "user", "collection"],
"DocumentRagQuery": ["query", "collection"],
"DocumentRagResponse": ["error", "response"],
"AgentRequest": ["question", "history"],
"AgentResponse": ["error"],

View file

@ -18,7 +18,6 @@ class TestOrchestrationFieldContracts:
def test_agent_request_orchestration_fields_roundtrip(self):
req = AgentRequest(
question="Test question",
user="testuser",
collection="default",
correlation_id="corr-123",
parent_session_id="parent-sess",
@ -42,7 +41,6 @@ class TestOrchestrationFieldContracts:
def test_agent_request_orchestration_fields_default_empty(self):
req = AgentRequest(
question="Test question",
user="testuser",
)
assert req.correlation_id == ""
@ -82,7 +80,6 @@ class TestSubagentCompletionStepContract:
)
req = AgentRequest(
question="goal",
user="testuser",
correlation_id="corr-123",
history=[step],
)
@ -126,7 +123,6 @@ class TestSynthesisStepContract:
req = AgentRequest(
question="Original question",
user="testuser",
pattern="supervisor",
correlation_id="",
session_id="parent-sess",

View file

@ -22,7 +22,6 @@ class TestRowsCassandraContracts:
# Create test object with all required fields
test_metadata = Metadata(
id="test-doc-001",
user="test_user",
collection="test_collection",
)
@ -47,7 +46,6 @@ class TestRowsCassandraContracts:
# Verify metadata structure
assert hasattr(test_object.metadata, 'id')
assert hasattr(test_object.metadata, 'user')
assert hasattr(test_object.metadata, 'collection')
# Verify types
@ -150,7 +148,6 @@ class TestRowsCassandraContracts:
original = ExtractedObject(
metadata=Metadata(
id="serial-001",
user="test_user",
collection="test_coll",
),
schema_name="test_schema",
@ -168,7 +165,6 @@ class TestRowsCassandraContracts:
# Verify round-trip
assert decoded.metadata.id == original.metadata.id
assert decoded.metadata.user == original.metadata.user
assert decoded.metadata.collection == original.metadata.collection
assert decoded.schema_name == original.schema_name
assert decoded.values == original.values
@ -228,8 +224,7 @@ class TestRowsCassandraContracts:
# Create test object
test_obj = ExtractedObject(
metadata=Metadata(
id="meta-001",
user="user123", # -> keyspace
id="meta-001", # -> keyspace
collection="coll456", # -> partition key
),
schema_name="table789", # -> table name
@ -242,7 +237,6 @@ class TestRowsCassandraContracts:
# - metadata.user -> Cassandra keyspace
# - schema_name -> Cassandra table
# - metadata.collection -> Part of primary key
assert test_obj.metadata.user # Required for keyspace
assert test_obj.schema_name # Required for table
assert test_obj.metadata.collection # Required for partition key
@ -256,7 +250,6 @@ class TestRowsCassandraContractsBatch:
# Create test object with multiple values in batch
test_metadata = Metadata(
id="batch-doc-001",
user="test_user",
collection="test_collection",
)
@ -302,7 +295,6 @@ class TestRowsCassandraContractsBatch:
"""Test empty batch ExtractedObject contract"""
test_metadata = Metadata(
id="empty-batch-001",
user="test_user",
collection="test_collection",
)
@ -324,7 +316,6 @@ class TestRowsCassandraContractsBatch:
"""Test single-item batch (backward compatibility) contract"""
test_metadata = Metadata(
id="single-batch-001",
user="test_user",
collection="test_collection",
)
@ -353,7 +344,6 @@ class TestRowsCassandraContractsBatch:
original = ExtractedObject(
metadata=Metadata(
id="batch-serial-001",
user="test_user",
collection="test_coll",
),
schema_name="test_schema",
@ -375,7 +365,6 @@ class TestRowsCassandraContractsBatch:
# Verify round-trip for batch
assert decoded.metadata.id == original.metadata.id
assert decoded.metadata.user == original.metadata.user
assert decoded.metadata.collection == original.metadata.collection
assert decoded.schema_name == original.schema_name
assert len(decoded.values) == len(original.values)
@ -425,8 +414,7 @@ class TestRowsCassandraContractsBatch:
# 3. Be stored in the same keyspace (user)
test_metadata = Metadata(
id="partition-test-001",
user="consistent_user", # Same keyspace
id="partition-test-001", # Same keyspace
collection="consistent_collection", # Same partition
)
@ -443,7 +431,6 @@ class TestRowsCassandraContractsBatch:
)
# Verify consistency contract
assert batch_object.metadata.user # Must have user for keyspace
assert batch_object.metadata.collection # Must have collection for partition key
# Verify unique primary keys in batch

View file

@ -21,29 +21,25 @@ class TestRowsGraphQLQueryContracts:
"""Test RowsQueryRequest schema structure and required fields"""
# Create test request with all required fields
test_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ customers { id name email } }',
variables={"status": "active", "limit": "10"},
operation_name="GetCustomers"
)
# Verify all required fields are present
assert hasattr(test_request, 'user')
assert hasattr(test_request, 'collection')
assert hasattr(test_request, 'collection')
assert hasattr(test_request, 'query')
assert hasattr(test_request, 'variables')
assert hasattr(test_request, 'operation_name')
# Verify field types
assert isinstance(test_request.user, str)
assert isinstance(test_request.collection, str)
assert isinstance(test_request.query, str)
assert isinstance(test_request.variables, dict)
assert isinstance(test_request.operation_name, str)
# Verify content
assert test_request.user == "test_user"
assert test_request.collection == "test_collection"
assert "customers" in test_request.query
assert test_request.variables["status"] == "active"
@ -53,15 +49,13 @@ class TestRowsGraphQLQueryContracts:
"""Test RowsQueryRequest with minimal required fields"""
# Create request with only essential fields
minimal_request = RowsQueryRequest(
user="user",
collection="collection",
query='{ test }',
variables={},
operation_name=""
)
# Verify minimal request is valid
assert minimal_request.user == "user"
assert minimal_request.collection == "collection"
assert minimal_request.query == '{ test }'
assert minimal_request.variables == {}
@ -187,22 +181,20 @@ class TestRowsGraphQLQueryContracts:
"""Test that request/response can be serialized/deserialized correctly"""
# Create original request
original_request = RowsQueryRequest(
user="serialization_test",
collection="test_data",
query='{ orders(limit: 5) { id total customer { name } } }',
variables={"limit": "5", "status": "active"},
operation_name="GetRecentOrders"
)
# Test request serialization using Pulsar schema
request_schema = AvroSchema(RowsQueryRequest)
# Encode and decode request
encoded_request = request_schema.encode(original_request)
decoded_request = request_schema.decode(encoded_request)
# Verify request round-trip
assert decoded_request.user == original_request.user
assert decoded_request.collection == original_request.collection
assert decoded_request.query == original_request.query
assert decoded_request.variables == original_request.variables
@ -245,7 +237,7 @@ class TestRowsGraphQLQueryContracts:
"""Test supported GraphQL query formats"""
# Test basic query
basic_query = RowsQueryRequest(
user="test", collection="test", query='{ customers { id } }',
collection="test", query='{ customers { id } }',
variables={}, operation_name=""
)
assert "customers" in basic_query.query
@ -254,7 +246,7 @@ class TestRowsGraphQLQueryContracts:
# Test query with variables
parameterized_query = RowsQueryRequest(
user="test", collection="test",
collection="test",
query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }',
variables={"status": "active", "limit": "10"},
operation_name="GetCustomers"
@ -266,7 +258,7 @@ class TestRowsGraphQLQueryContracts:
# Test complex nested query
nested_query = RowsQueryRequest(
user="test", collection="test",
collection="test",
query='''
{
customers(limit: 10) {
@ -297,7 +289,7 @@ class TestRowsGraphQLQueryContracts:
# This test verifies the current contract, though ideally we'd support all JSON types
variables_test = RowsQueryRequest(
user="test", collection="test", query='{ test }',
collection="test", query='{ test }',
variables={
"string_var": "test_value",
"numeric_var": "123", # Numbers as strings due to Map(String()) limitation
@ -318,22 +310,18 @@ class TestRowsGraphQLQueryContracts:
def test_cassandra_context_fields_contract(self):
"""Test that request contains necessary fields for Cassandra operations"""
# Verify request has fields needed for Cassandra keyspace/table targeting
# Verify request has fields needed for partition key targeting
request = RowsQueryRequest(
user="keyspace_name", # Maps to Cassandra keyspace
collection="partition_collection", # Used in partition key
query='{ objects { id } }',
variables={}, operation_name=""
)
# These fields are required for proper Cassandra operations
assert request.user # Required for keyspace identification
assert request.collection # Required for partition key
# Required for partition key
assert request.collection
# Verify field naming follows TrustGraph patterns (matching other query services)
# This matches TriplesQueryRequest, DocumentEmbeddingsRequest patterns
assert hasattr(request, 'user') # Same as TriplesQueryRequest.user
assert hasattr(request, 'collection') # Same as TriplesQueryRequest.collection
assert hasattr(request, 'collection')
def test_graphql_extensions_contract(self):
"""Test GraphQL extensions field format and usage"""
@ -405,7 +393,7 @@ class TestRowsGraphQLQueryContracts:
# Request to execute specific operation
multi_op_request = RowsQueryRequest(
user="test", collection="test",
collection="test",
query=multi_op_query,
variables={},
operation_name="GetCustomers"
@ -418,7 +406,7 @@ class TestRowsGraphQLQueryContracts:
# Test single operation (operation_name optional)
single_op_request = RowsQueryRequest(
user="test", collection="test",
collection="test",
query='{ customers { id } }',
variables={}, operation_name=""
)

View file

@ -0,0 +1,74 @@
"""
Contract tests for schema dataclass field sets.
These pin the *field names* of small, widely-constructed schema dataclasses
so that any rename, removal, or accidental addition fails CI loudly instead
of waiting for a runtime TypeError on the next websocket message.
Background: in v2.2 the `Metadata` dataclass dropped a `metadata: list[Triple]`
field but several call sites kept passing `Metadata(metadata=...)`. The bug
was only discovered when a websocket import dispatcher received its first
real message in production. A trivial structural assertion of the kind
below would have caught it at unit-test time.
Add to this file whenever a schema rename burns you. The cost of a frozen
field set is a one-line update when you intentionally evolve the schema; the
benefit is that every call site is forced to come along for the ride.
"""
import dataclasses
import pytest
from trustgraph.schema import (
Metadata,
EntityContext,
EntityEmbeddings,
ChunkEmbeddings,
)
def _field_names(dc):
return {f.name for f in dataclasses.fields(dc)}
@pytest.mark.contract
class TestSchemaFieldContracts:
"""Pin the field set of dataclasses that get constructed all over the
codebase. If you intentionally change one of these, update the
expected set in the same commit that diff will surface every call
site that needs to come along."""
def test_metadata_fields(self):
# NOTE: there is no `metadata` field. A previous regression
# constructed Metadata(metadata=...) and crashed at runtime.
# `user` was also dropped in the workspace refactor — workspace
# now flows via flow.workspace, not via message payload.
assert _field_names(Metadata) == {
"id",
"root",
"collection",
}
def test_entity_embeddings_fields(self):
# NOTE: the embedding field is `vector` (singular, list[float]).
# There is no `vectors` field. Several call sites historically
# passed `vectors=` and crashed at runtime.
assert _field_names(EntityEmbeddings) == {
"entity",
"vector",
"chunk_id",
}
def test_chunk_embeddings_fields(self):
# Same `vector` (singular) convention as EntityEmbeddings.
assert _field_names(ChunkEmbeddings) == {
"chunk_id",
"vector",
}
def test_entity_context_fields(self):
assert _field_names(EntityContext) == {
"entity",
"context",
"chunk_id",
}

View file

@ -93,7 +93,6 @@ class TestStructuredDataSchemaContracts:
# Arrange
metadata = Metadata(
id="structured-data-001",
user="test_user",
collection="test_collection",
)
@ -118,7 +117,6 @@ class TestStructuredDataSchemaContracts:
# Arrange
metadata = Metadata(
id="extracted-obj-001",
user="test_user",
collection="test_collection",
)
@ -143,7 +141,6 @@ class TestStructuredDataSchemaContracts:
# Arrange
metadata = Metadata(
id="extracted-batch-001",
user="test_user",
collection="test_collection",
)
@ -177,7 +174,6 @@ class TestStructuredDataSchemaContracts:
# Arrange
metadata = Metadata(
id="extracted-empty-001",
user="test_user",
collection="test_collection",
)
@ -277,7 +273,6 @@ class TestStructuredEmbeddingsContracts:
# Arrange
metadata = Metadata(
id="struct-embed-001",
user="test_user",
collection="test_collection",
)
@ -308,7 +303,7 @@ class TestStructuredDataSerializationContracts:
def test_structured_data_submission_serialization(self):
"""Test StructuredDataSubmission serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col")
metadata = Metadata(id="test", collection="col")
submission_data = {
"metadata": metadata,
"format": "json",
@ -323,7 +318,7 @@ class TestStructuredDataSerializationContracts:
def test_extracted_object_serialization(self):
"""Test ExtractedObject serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col")
metadata = Metadata(id="test", collection="col")
object_data = {
"metadata": metadata,
"schema_name": "test_schema",
@ -373,7 +368,7 @@ class TestStructuredDataSerializationContracts:
def test_extracted_object_batch_serialization(self):
"""Test ExtractedObject batch serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col")
metadata = Metadata(id="test", collection="col")
batch_object_data = {
"metadata": metadata,
"schema_name": "test_schema",
@ -392,7 +387,7 @@ class TestStructuredDataSerializationContracts:
def test_extracted_object_empty_batch_serialization(self):
"""Test ExtractedObject empty batch serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col")
metadata = Metadata(id="test", collection="col")
empty_batch_data = {
"metadata": metadata,
"schema_name": "test_schema",

View file

@ -188,7 +188,7 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
chunk_type="answer",
message_type="answer",
content="4",
end_of_message=True,
end_of_dialog=True,
@ -210,7 +210,7 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
chunk_type="thought",
message_type="thought",
content="I need to solve this.",
end_of_message=True,
end_of_dialog=False,
@ -233,7 +233,7 @@ class TestAgentTranslatorCompletionFlags:
# Test thought message
thought_response = AgentResponse(
chunk_type="thought",
message_type="thought",
content="Processing...",
end_of_message=True,
end_of_dialog=False,
@ -247,7 +247,7 @@ class TestAgentTranslatorCompletionFlags:
# Test observation message
observation_response = AgentResponse(
chunk_type="observation",
message_type="observation",
content="Result found",
end_of_message=True,
end_of_dialog=False,
@ -268,7 +268,7 @@ class TestAgentTranslatorCompletionFlags:
# Streaming format with end_of_dialog=True
response = AgentResponse(
chunk_type="answer",
message_type="answer",
content="",
end_of_message=True,
end_of_dialog=True,

View file

@ -418,55 +418,55 @@ def sample_streaming_agent_response():
"""Sample streaming agent response chunks"""
return [
{
"chunk_type": "thought",
"message_type": "thought",
"content": "I need to search",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "thought",
"message_type": "thought",
"content": " for information",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "thought",
"message_type": "thought",
"content": " about machine learning.",
"end_of_message": True,
"end_of_dialog": False
},
{
"chunk_type": "action",
"message_type": "action",
"content": "knowledge_query",
"end_of_message": True,
"end_of_dialog": False
},
{
"chunk_type": "observation",
"message_type": "observation",
"content": "Machine learning is",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "observation",
"message_type": "observation",
"content": " a subset of AI.",
"end_of_message": True,
"end_of_dialog": False
},
{
"chunk_type": "final-answer",
"message_type": "final-answer",
"content": "Machine learning",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "final-answer",
"message_type": "final-answer",
"content": " is a subset",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "final-answer",
"message_type": "final-answer",
"content": " of artificial intelligence.",
"end_of_message": True,
"end_of_dialog": True
@ -494,10 +494,10 @@ def streaming_chunk_collector():
"""Concatenate all chunk content"""
return "".join(self.chunks)
def get_chunk_types(self):
def get_message_types(self):
"""Get list of chunk types if chunks are dicts"""
if self.chunks and isinstance(self.chunks[0], dict):
return [c.get("chunk_type") for c in self.chunks]
return [c.get("message_type") for c in self.chunks]
return []
def verify_streaming_protocol(self):

View file

@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
from trustgraph.agent.react.types import Action, Final, Tool, Argument
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
# Mock prompt client
prompt_client = AsyncMock()
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for information about machine learning
Action: knowledge_query
Args: {
"question": "What is machine learning?"
}"""
)
# Mock graph RAG client
graph_rag_client = AsyncMock()
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
# Mock text completion client
text_completion_client = AsyncMock()
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
text_completion_client.question.return_value = PromptResult(
response_type="text",
text="Machine learning involves algorithms that improve through experience."
)
# Mock MCP tool client
mcp_tool_client = AsyncMock()
@ -147,8 +154,11 @@ Args: {
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
"""Test agent manager returning final answer"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have enough information to answer the question
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
)
question = "What is machine learning?"
history = []
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
"""Test ReAct cycle ending with final answer"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide a direct answer
Final Answer: Machine learning is a branch of artificial intelligence."""
)
question = "What is machine learning?"
history = []
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
for tool_name, expected_service in tool_scenarios:
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: I need to use {tool_name}
Action: {tool_name}
Args: {{
"question": "test question"
}}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -284,11 +300,14 @@ Args: {{
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
"""Test agent manager error handling for unknown tool"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to use an unknown tool
Action: unknown_tool
Args: {
"param": "value"
}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -308,11 +327,13 @@ Args: {
think_callback = AsyncMock()
observe_callback = AsyncMock()
# Act & Assert
with pytest.raises(Exception) as exc_info:
await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
assert "Tool execution failed" in str(exc_info.value)
# Act - tool errors are now caught and returned as observations
result = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
# Assert - error captured on the action, not raised
assert result.tool_error is not None
assert "Tool execution failed" in result.tool_error
assert "Error:" in result.observation
@pytest.mark.asyncio
async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context):
@ -321,11 +342,14 @@ Args: {
question = "Find information about AI and summarize it"
# Mock multi-step reasoning
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for AI information first
Action: knowledge_query
Args: {
"question": "What is artificial intelligence?"
}"""
)
# Act
action = await agent_manager.reason(question, [], mock_flow_context)
@ -372,9 +396,12 @@ Args: {
# Format arguments as JSON
import json
args_json = json.dumps(test_case['arguments'], indent=4)
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: Using {test_case['action']}
Action: {test_case['action']}
Args: {args_json}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -507,15 +534,17 @@ Args: {
]
for test_case in test_cases:
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=test_case["response"]
)
if test_case["error_contains"]:
# Should raise an error
with pytest.raises(RuntimeError) as exc_info:
await agent_manager.reason("test question", [], mock_flow_context)
assert "Failed to parse agent response" in str(exc_info.value)
assert test_case["error_contains"] in str(exc_info.value)
# Parse errors now return an Action with tool_error
result = await agent_manager.reason("test question", [], mock_flow_context)
assert isinstance(result, Action)
assert result.name == "__parse_error__"
assert result.tool_error is not None
else:
# Should succeed
action = await agent_manager.reason("test question", [], mock_flow_context)
@ -527,13 +556,16 @@ Args: {
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
"""Test edge cases in text parsing"""
# Test response with markdown code blocks
mock_flow_context("prompt-request").agent_react.return_value = """```
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""```
Thought: I need to search for information
Action: knowledge_query
Args: {
"question": "What is AI?"
}
```"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -541,15 +573,18 @@ Args: {
assert action.name == "knowledge_query"
# Test response with extra whitespace
mock_flow_context("prompt-request").agent_react.return_value = """
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""
Thought: I need to think about this
Action: knowledge_query
Thought: I need to think about this
Action: knowledge_query
Args: {
"question": "test"
}
"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -560,7 +595,9 @@ Args: {
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
"""Test handling of multi-line thoughts and final answers"""
# Multi-line thought
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to consider multiple factors:
1. The user's question is complex
2. I should search for comprehensive information
3. This requires using the knowledge query tool
@ -568,6 +605,7 @@ Action: knowledge_query
Args: {
"question": "complex query"
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -575,13 +613,16 @@ Args: {
assert "knowledge query tool" in action.thought
# Multi-line final answer
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have gathered enough information
Final Answer: Here is a comprehensive answer:
1. First point about the topic
2. Second point with details
3. Final conclusion
This covers all aspects of the question."""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@ -593,13 +634,16 @@ This covers all aspects of the question."""
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
"""Test JSON arguments with special characters and edge cases"""
# Test with special characters in JSON (properly escaped)
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Processing special characters
Action: knowledge_query
Args: {
"question": "What about \\"quotes\\" and 'apostrophes'?",
"context": "Line 1\\nLine 2\\tTabbed",
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -608,7 +652,9 @@ Args: {
assert "@#$%^&*" in action.arguments["special"]
# Test with nested JSON
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Complex arguments
Action: web_search
Args: {
"query": "test",
@ -621,6 +667,7 @@ Args: {
}
}
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -632,7 +679,9 @@ Args: {
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
"""Test final answers that contain JSON-like content"""
# Final answer with JSON content
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide the data in JSON format
Final Answer: {
"result": "success",
"data": {
@ -642,6 +691,7 @@ Final Answer: {
},
"confidence": 0.95
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@ -792,11 +842,14 @@ Final Answer: {
agent = AgentManager(tools=custom_tools, additional_context="")
# Mock response for custom collection query
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search in the research papers
Action: knowledge_query_custom
Args: {
"question": "Latest AI research?"
}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()

View file

@ -10,11 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl
from trustgraph.agent.react.types import Tool, Argument
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_agent_streaming_chunks,
assert_streaming_chunks_valid,
assert_callback_invoked,
assert_chunk_types_valid,
assert_message_types_valid,
)
@ -51,10 +52,10 @@ Args: {
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
return full_text
return PromptResult(response_type="text", text=full_text)
client.agent_react.side_effect = agent_react_streaming
return client
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
for i, chunk in enumerate(chunks):
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk + " ", is_final)
return response
return response
return PromptResult(response_type="text", text=response)
return PromptResult(response_type="text", text=response)
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Error
)
from trustgraph.agent.react.service import Processor
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -57,7 +58,7 @@ class TestAgentStructuredQueryIntegration:
async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
"""Test basic agent integration with structured query tool"""
# Arrange - Load tool configuration
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
# Create agent request
request = AgentRequest(
@ -65,7 +66,6 @@ class TestAgentStructuredQueryIntegration:
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -95,11 +95,14 @@ class TestAgentStructuredQueryIntegration:
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from New York using structured query
Action: structured-query
Args: {
"question": "Find all customers from New York"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -115,6 +118,7 @@ Args: {
# Mock flow parameter in agent_processor.on_request
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -142,14 +146,13 @@ Args: {
async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
"""Test agent handling of structured query errors"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="Find data from a table that doesn't exist using structured query.",
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -173,11 +176,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query for a table that might not exist
Action: structured-query
Args: {
"question": "Find data from a table that doesn't exist"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -192,6 +198,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -214,14 +221,13 @@ Args: {
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
"""Test agent using structured query in multi-step reasoning"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="First find all customers from California, then tell me how many orders they have made.",
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -250,11 +256,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from California first
Action: structured-query
Args: {
"question": "Find all customers from California"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -269,6 +278,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -303,14 +313,13 @@ Args: {
}
}
await agent_processor.on_tools_config(tool_config_with_collection, "v1")
await agent_processor.on_tools_config("default", tool_config_with_collection, "v1")
request = AgentRequest(
question="Query the sales data for recent transactions.",
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -339,11 +348,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query the sales data
Action: structured-query
Args: {
"question": "Query the sales data for recent transactions"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -358,6 +370,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -381,10 +394,10 @@ Args: {
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
"""Test that structured query tool arguments are properly validated"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
# Check that the tool was registered with correct arguments
tools = agent_processor.agent.tools
tools = agent_processor.agents["default"].tools
assert "structured-query" in tools
structured_tool = tools["structured-query"]
@ -401,14 +414,13 @@ Args: {
async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
"""Test that structured query results are properly formatted for agent consumption"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="Get customer information and format it nicely.",
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -447,11 +459,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to get customer information
Action: structured-query
Args: {
"question": "Get customer information and format it nicely"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -466,6 +481,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)

View file

@ -40,14 +40,13 @@ class TestEndToEndConfigurationFlow:
# Create a mock message to trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
# This should create TrustGraph with environment config
await processor.store_triples(mock_message)
await processor.store_triples('test_user', mock_message)
# Verify Cluster was created with correct hosts
mock_cluster.assert_called_once()
@ -144,13 +143,12 @@ class TestConfigurationPriorityEndToEnd:
# Trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('test_user', mock_message)
# Should use CLI parameters, not environment
mock_cluster.assert_called_once()
@ -201,7 +199,6 @@ class TestConfigurationPriorityEndToEnd:
# Mock query to trigger TrustGraph creation
mock_query = MagicMock()
mock_query.user = 'default_user'
mock_query.collection = 'default_collection'
mock_query.s = None
mock_query.p = None
@ -213,7 +210,7 @@ class TestConfigurationPriorityEndToEnd:
mock_tg_instance.get_all.return_value = []
processor.tg = mock_tg_instance
await processor.query_triples(mock_query)
await processor.query_triples('default_user', mock_query)
# Should use defaults
mock_cluster.assert_called_once()
@ -244,13 +241,12 @@ class TestNoBackwardCompatibilityEndToEnd:
# Trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'legacy_user'
mock_message.metadata.collection = 'legacy_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('legacy_user', mock_message)
# Should use defaults since old parameters are not recognized
mock_cluster.assert_called_once()
@ -302,13 +298,12 @@ class TestNoBackwardCompatibilityEndToEnd:
# Trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'precedence_user'
mock_message.metadata.collection = 'precedence_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('precedence_user', mock_message)
# Should use new parameters, not old ones
mock_cluster.assert_called_once()
@ -354,13 +349,12 @@ class TestMultipleHostsHandling:
# Trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'single_user'
mock_message.metadata.collection = 'single_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('single_user', mock_message)
# Single host should be converted to list
mock_cluster.assert_called_once()

View file

@ -115,7 +115,7 @@ class TestCassandraIntegration:
# Create test message
storage_message = Triples(
metadata=Metadata(user="testuser", collection="testcol"),
metadata=Metadata(collection="testcol"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.org/person1"),
@ -178,7 +178,7 @@ class TestCassandraIntegration:
# Store test data for querying
query_test_message = Triples(
metadata=Metadata(user="testuser", collection="testcol"),
metadata=Metadata(collection="testcol"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.org/alice"),
@ -212,7 +212,6 @@ class TestCassandraIntegration:
p=None, # None for wildcard
o=None, # None for wildcard
limit=10,
user="testuser",
collection="testcol"
)
s_results = await query_processor.query_triples(s_query)
@ -232,7 +231,6 @@ class TestCassandraIntegration:
p=Term(type=IRI, iri="http://example.org/knows"),
o=None, # None for wildcard
limit=10,
user="testuser",
collection="testcol"
)
p_results = await query_processor.query_triples(p_query)
@ -259,7 +257,7 @@ class TestCassandraIntegration:
# Create multiple coroutines for concurrent storage
async def store_person_data(person_id, name, age, department):
message = Triples(
metadata=Metadata(user="concurrent_test", collection="people"),
metadata=Metadata(collection="people"),
triples=[
Triple(
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
@ -329,7 +327,7 @@ class TestCassandraIntegration:
# Create a knowledge graph about a company
company_graph = Triples(
metadata=Metadata(user="integration_test", collection="company"),
metadata=Metadata(collection="company"),
triples=[
# People and their types
Triple(

View file

@ -10,6 +10,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
# Sample chunk content for testing - maps chunk_id to content
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
client = AsyncMock()
client.document_prompt.return_value = (
"Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
client.document_prompt.return_value = PromptResult(
response_type="text",
text=(
"Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
)
)
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -93,7 +99,6 @@ class TestDocumentRagIntegration:
# Act
result = await document_rag.query(
query=query,
user=user,
collection=collection,
doc_limit=doc_limit
)
@ -104,7 +109,6 @@ class TestDocumentRagIntegration:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
limit=doc_limit,
user=user,
collection=collection
)
@ -119,6 +123,7 @@ class TestDocumentRagIntegration:
)
# Verify final response
result, usage = result
assert result is not None
assert isinstance(result, str)
assert "machine learning" in result.lower()
@ -131,7 +136,11 @@ class TestDocumentRagIntegration:
"""Test DocumentRAG behavior when no documents are retrieved"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
mock_prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="I couldn't find any relevant documents for your query."
)
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
@ -152,7 +161,8 @@ class TestDocumentRagIntegration:
documents=[]
)
assert result == "I couldn't find any relevant documents for your query."
result_text, usage = result
assert result_text == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
@ -266,14 +276,12 @@ class TestDocumentRagIntegration:
# Act
await document_rag.query(
f"query from {user} in {collection}",
user=user,
collection=collection
)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection
@pytest.mark.asyncio
@ -341,6 +349,5 @@ class TestDocumentRagIntegration:
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == "trustgraph"
assert call_args.kwargs['collection'] == "default"
assert call_args.kwargs['limit'] == 20

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
return full_text
return PromptResult(response_type="text", text=full_text)
client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -104,7 +107,6 @@ class TestDocumentRagStreaming:
# Act
result = await document_rag_streaming.query(
query=query,
user="test_user",
collection="test_collection",
doc_limit=10,
streaming=True,
@ -119,11 +121,12 @@ class TestDocumentRagStreaming:
collector.verify_streaming_protocol()
# Verify full response matches concatenated chunks
result_text, usage = result
full_from_chunks = collector.get_full_text()
assert result == full_from_chunks
assert result_text == full_from_chunks
# Verify content is reasonable
assert len(result) > 0
assert len(result_text) > 0
@pytest.mark.asyncio
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
@ -137,7 +140,6 @@ class TestDocumentRagStreaming:
# Act - Non-streaming
non_streaming_result = await document_rag_streaming.query(
query=query,
user=user,
collection=collection,
doc_limit=doc_limit,
streaming=False
@ -151,7 +153,6 @@ class TestDocumentRagStreaming:
streaming_result = await document_rag_streaming.query(
query=query,
user=user,
collection=collection,
doc_limit=doc_limit,
streaming=True,
@ -159,9 +160,11 @@ class TestDocumentRagStreaming:
)
# Assert - Results should be equivalent
assert streaming_result == non_streaming_result
non_streaming_text, _ = non_streaming_result
streaming_text, _ = streaming_result
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_result
assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
@ -172,7 +175,6 @@ class TestDocumentRagStreaming:
# Act
result = await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=5,
streaming=True,
@ -180,8 +182,9 @@ class TestDocumentRagStreaming:
)
# Assert
result_text, usage = result
assert callback.call_count > 0
assert result is not None
assert result_text is not None
# Verify all callback invocations had string arguments
for call in callback.call_args_list:
@ -193,7 +196,6 @@ class TestDocumentRagStreaming:
# Arrange & Act
result = await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=5,
streaming=True,
@ -202,7 +204,8 @@ class TestDocumentRagStreaming:
# Assert - Should complete without error
assert result is not None
assert isinstance(result, str)
result_text, usage = result
assert isinstance(result_text, str)
@pytest.mark.asyncio
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
@ -215,7 +218,6 @@ class TestDocumentRagStreaming:
# Act
result = await document_rag_streaming.query(
query="unknown topic",
user="test_user",
collection="test_collection",
doc_limit=10,
streaming=True,
@ -223,7 +225,8 @@ class TestDocumentRagStreaming:
)
# Assert - Should still produce streamed response
assert result is not None
result_text, usage = result
assert result_text is not None
assert callback.call_count > 0
@pytest.mark.asyncio
@ -238,7 +241,6 @@ class TestDocumentRagStreaming:
with pytest.raises(Exception) as exc_info:
await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=5,
streaming=True,
@ -263,7 +265,6 @@ class TestDocumentRagStreaming:
# Act
result = await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=limit,
streaming=True,
@ -271,7 +272,8 @@ class TestDocumentRagStreaming:
)
# Assert
assert result is not None
result_text, usage = result
assert result_text is not None
assert callback.call_count > 0
# Verify doc_limit was passed correctly
@ -290,7 +292,6 @@ class TestDocumentRagStreaming:
# Act
await document_rag_streaming.query(
query="test query",
user=user,
collection=collection,
doc_limit=10,
streaming=True,
@ -299,5 +300,4 @@ class TestDocumentRagStreaming:
# Assert - Verify user/collection were passed to document embeddings client
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection

View file

@ -12,6 +12,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
# 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return "" # No edges scored
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning":
return "" # No reasoning
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
return (
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
return PromptResult(
response_type="text",
text=(
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
)
return ""
return PromptResult(response_type="text", text="")
client.prompt.side_effect = mock_prompt
return client
@ -142,7 +146,6 @@ class TestGraphRagIntegration:
# Act
response = await graph_rag.query(
query=query,
user=user,
collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit,
@ -159,7 +162,6 @@ class TestGraphRagIntegration:
call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
assert call_args.kwargs['limit'] == entity_limit
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection
# 3. Should query triples to build knowledge subgraph
@ -169,6 +171,7 @@ class TestGraphRagIntegration:
assert mock_prompt_client.prompt.call_count == 4
# Verify final response
response, usage = response
assert response is not None
assert isinstance(response, str)
assert "machine learning" in response.lower()
@ -199,7 +202,6 @@ class TestGraphRagIntegration:
# Act
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection",
entity_limit=config["entity_limit"],
triple_limit=config["triple_limit"]
@ -219,7 +221,6 @@ class TestGraphRagIntegration:
with pytest.raises(Exception) as exc_info:
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection"
)
@ -242,7 +243,6 @@ class TestGraphRagIntegration:
# Act
response = await graph_rag.query(
query="unknown topic",
user="test_user",
collection="test_collection",
explain_callback=collect_provenance
)
@ -262,7 +262,6 @@ class TestGraphRagIntegration:
# First query
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection"
)
@ -272,7 +271,6 @@ class TestGraphRagIntegration:
# Second identical query
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection"
)
@ -284,26 +282,27 @@ class TestGraphRagIntegration:
assert second_call_count >= 0 # Should complete without errors
@pytest.mark.asyncio
async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client):
"""Test that different users/collections are properly isolated"""
async def test_graph_rag_multi_collection_isolation(self, graph_rag, mock_graph_embeddings_client):
"""Test that different collections propagate through to the embeddings query.
Workspace isolation is enforced by flow.workspace at the service
boundary not by parameters on GraphRag.query so this test
verifies collection routing only.
"""
# Arrange
query = "test query"
user1, collection1 = "user1", "collection1"
user2, collection2 = "user2", "collection2"
collection1 = "collection1"
collection2 = "collection2"
# Act
await graph_rag.query(query=query, user=user1, collection=collection1)
await graph_rag.query(query=query, user=user2, collection=collection2)
await graph_rag.query(query=query, collection=collection1)
await graph_rag.query(query=query, collection=collection2)
# Assert - Both users should have separate queries
# Assert - Each call propagated its collection
assert mock_graph_embeddings_client.query.call_count == 2
# Verify first call
first_call = mock_graph_embeddings_client.query.call_args_list[0]
assert first_call.kwargs['user'] == user1
assert first_call.kwargs['collection'] == collection1
# Verify second call
second_call = mock_graph_embeddings_client.query.call_args_list[1]
assert second_call.kwargs['user'] == user2
assert second_call.kwargs['collection'] == collection2

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_rag_streaming_chunks,
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return '{"id": "abc12345", "score": 0.9}\n'
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
elif prompt_id == "kg-edge-reasoning":
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
return full_text
return ""
return PromptResult(response_type="text", text=full_text)
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@ -115,7 +116,6 @@ class TestGraphRagStreaming:
# Act - query() returns response, provenance via callback
response = await graph_rag_streaming.query(
query=query,
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collector.collect,
@ -123,6 +123,7 @@ class TestGraphRagStreaming:
)
# Assert
response, usage = response
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
@ -152,7 +153,6 @@ class TestGraphRagStreaming:
# Act - Non-streaming
non_streaming_response = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
streaming=False
)
@ -165,16 +165,17 @@ class TestGraphRagStreaming:
streaming_response = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
streaming=True,
chunk_callback=collect
)
# Assert - Results should be equivalent
assert streaming_response == non_streaming_response
non_streaming_text, _ = non_streaming_response
streaming_text, _ = streaming_response
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_response
assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
@ -185,7 +186,6 @@ class TestGraphRagStreaming:
# Act
response = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -205,7 +205,6 @@ class TestGraphRagStreaming:
# Arrange & Act
response = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=None # No callback provided
@ -213,7 +212,8 @@ class TestGraphRagStreaming:
# Assert - Should complete without error
assert response is not None
assert isinstance(response, str)
response_text, usage = response
assert isinstance(response_text, str)
@pytest.mark.asyncio
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
@ -226,7 +226,6 @@ class TestGraphRagStreaming:
# Act
response = await graph_rag_streaming.query(
query="unknown topic",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -248,7 +247,6 @@ class TestGraphRagStreaming:
with pytest.raises(Exception) as exc_info:
await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -268,7 +266,6 @@ class TestGraphRagStreaming:
# Act
await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
entity_limit=entity_limit,
triple_limit=triple_limit,

View file

@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend):
triples_obj = Triples(
metadata=Metadata(
id=f"export-msg-{i}",
user=msg_data["metadata"]["user"],
collection=msg_data["metadata"]["collection"],
),
triples=to_subgraph(msg_data["triples"]),

View file

@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
# Mock prompt client for definitions extraction
prompt_client = AsyncMock()
prompt_client.extract_definitions.return_value = [
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
]
prompt_client.extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
]
)
# Mock prompt client for relationships extraction
prompt_client.extract_relationships.return_value = [
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "is_used_in",
"object": "Machine Learning",
"object-entity": True
}
]
prompt_client.extract_relationships.return_value = PromptResult(
response_type="jsonl",
objects=[
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "is_used_in",
"object": "Machine Learning",
"object-entity": True
}
]
)
# Mock producers for output streams
triples_producer = AsyncMock()
@ -90,7 +97,6 @@ class TestKnowledgeGraphPipelineIntegration:
return Chunk(
metadata=Metadata(
id="doc-123",
user="test_user",
collection="test_collection",
),
chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
@ -240,7 +246,6 @@ class TestKnowledgeGraphPipelineIntegration:
# Arrange
metadata = Metadata(
id="test-doc",
user="test_user",
collection="test_collection",
)
@ -298,7 +303,6 @@ class TestKnowledgeGraphPipelineIntegration:
# Arrange
metadata = Metadata(
id="test-doc",
user="test_user",
collection="test_collection",
)
@ -368,7 +372,6 @@ class TestKnowledgeGraphPipelineIntegration:
sample_triples = Triples(
metadata=Metadata(
id="test-doc",
user="test_user",
collection="test_collection",
),
triples=[
@ -383,11 +386,14 @@ class TestKnowledgeGraphPipelineIntegration:
mock_msg = MagicMock()
mock_msg.value.return_value = sample_triples
mock_flow = MagicMock()
mock_flow.workspace = "test_workspace"
# Act
await processor.on_triples(mock_msg, None, None)
await processor.on_triples(mock_msg, None, mock_flow)
# Assert
mock_cassandra_store.add_triples.assert_called_once_with(sample_triples)
mock_cassandra_store.add_triples.assert_called_once_with("test_workspace", sample_triples)
@pytest.mark.asyncio
async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store):
@ -400,7 +406,6 @@ class TestKnowledgeGraphPipelineIntegration:
sample_embeddings = GraphEmbeddings(
metadata=Metadata(
id="test-doc",
user="test_user",
collection="test_collection",
),
entities=[
@ -414,11 +419,14 @@ class TestKnowledgeGraphPipelineIntegration:
mock_msg = MagicMock()
mock_msg.value.return_value = sample_embeddings
mock_flow = MagicMock()
mock_flow.workspace = "test_workspace"
# Act
await processor.on_graph_embeddings(mock_msg, None, None)
await processor.on_graph_embeddings(mock_msg, None, mock_flow)
# Assert
mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings)
mock_cassandra_store.add_graph_embeddings.assert_called_once_with("test_workspace", sample_embeddings)
@pytest.mark.asyncio
async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor,
@ -489,7 +497,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of empty extraction results"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = []
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[]
)
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@ -510,7 +521,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of invalid extraction response format"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="text",
text="invalid format"
) # Should be jsonl with objects list
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@ -528,16 +542,19 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
"""Test entity filtering and validation in extraction"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = [
{"entity": "Valid Entity", "definition": "Valid definition"},
{"entity": "", "definition": "Empty entity"}, # Should be filtered
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
{"entity": None, "definition": "None entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
]
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[
{"entity": "Valid Entity", "definition": "Valid definition"},
{"entity": "", "definition": "Empty entity"}, # Should be filtered
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
{"entity": None, "definition": "None entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
]
)
sample_chunk = Chunk(
metadata=Metadata(id="test", user="user", collection="collection"),
metadata=Metadata(id="test", collection="collection"),
chunk=b"Test chunk"
)
@ -564,7 +581,7 @@ class TestKnowledgeGraphPipelineIntegration:
# Arrange
large_chunk_batch = [
Chunk(
metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"),
metadata=Metadata(id=f"doc-{i}", collection="collection"),
chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
)
for i in range(100) # Large batch
@ -601,7 +618,6 @@ class TestKnowledgeGraphPipelineIntegration:
# Arrange
original_metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -630,9 +646,7 @@ class TestKnowledgeGraphPipelineIntegration:
entity_contexts_call = entity_contexts_producer.send.call_args[0][0]
assert triples_call.metadata.id == "test-doc-123"
assert triples_call.metadata.user == "test_user"
assert triples_call.metadata.collection == "test_collection"
assert entity_contexts_call.metadata.id == "test-doc-123"
assert entity_contexts_call.metadata.user == "test_user"
assert entity_contexts_call.metadata.collection == "test_collection"

View file

@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration:
)
# Set up schemas
proc.schemas = sample_schemas
proc.schemas = {"default": dict(sample_schemas)}
# Mock the client method
proc.client = MagicMock()
@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration:
}
# Act - Update configuration
await integration_processor.on_schema_config(new_schema_config, "v2")
await integration_processor.on_schema_config("default", new_schema_config, "v2")
# Arrange - Test query using new schema
request = QuestionToStructuredQueryRequest(
@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration:
await integration_processor.on_message(msg, consumer, flow)
# Assert
assert "inventory" in integration_processor.schemas
assert "inventory" in integration_processor.schemas["default"]
response_call = flow_response.send.call_args
response = response_call[0][0]
assert response.detected_schemas == ["inventory"]
@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration:
graphql_generation_template="custom-graphql-generator"
)
custom_processor.schemas = sample_schemas
custom_processor.schemas = {"default": dict(sample_schemas)}
custom_processor.client = MagicMock()
request = QuestionToStructuredQueryRequest(
@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration:
] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
)
integration_processor.schemas.update(large_schema_set)
integration_processor.schemas["default"].update(large_schema_set)
request = QuestionToStructuredQueryRequest(
question="Show me data from table_05 and table_12",
@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration:
msg.properties.return_value = {"id": f"concurrent-test-{i}"}
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Chunk, ExtractedObject, Metadata, RowSchema, Field,
PromptRequest, PromptResponse
)
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
if schema_name == "customer_records":
if "john" in text.lower():
return [
{
"customer_id": "CUST001",
"name": "John Smith",
"email": "john.smith@email.com",
"phone": "555-0123"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"customer_id": "CUST001",
"name": "John Smith",
"email": "john.smith@email.com",
"phone": "555-0123"
}
]
)
elif "jane" in text.lower():
return [
{
"customer_id": "CUST002",
"name": "Jane Doe",
"email": "jane.doe@email.com",
"phone": ""
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"customer_id": "CUST002",
"name": "Jane Doe",
"email": "jane.doe@email.com",
"phone": ""
}
]
)
else:
return []
return PromptResult(response_type="jsonl", objects=[])
elif schema_name == "product_catalog":
if "laptop" in text.lower():
return [
{
"product_id": "PROD001",
"name": "Gaming Laptop",
"price": "1299.99",
"category": "electronics"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"product_id": "PROD001",
"name": "Gaming Laptop",
"price": "1299.99",
"category": "electronics"
}
]
)
elif "book" in text.lower():
return [
{
"product_id": "PROD002",
"name": "Python Programming Guide",
"price": "49.99",
"category": "books"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"product_id": "PROD002",
"name": "Python Programming Guide",
"price": "49.99",
"category": "books"
}
]
)
else:
return []
return []
return PromptResult(response_type="jsonl", objects=[])
return PromptResult(response_type="jsonl", objects=[])
prompt_client.extract_objects.side_effect = mock_extract_objects
@ -172,6 +185,7 @@ class TestObjectExtractionServiceIntegration:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
return context
@pytest.mark.asyncio
@ -184,20 +198,21 @@ class TestObjectExtractionServiceIntegration:
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Act
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Assert
assert len(processor.schemas) == 2
assert "customer_records" in processor.schemas
assert "product_catalog" in processor.schemas
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 2
assert "customer_records" in ws_schemas
assert "product_catalog" in ws_schemas
# Verify customer schema
customer_schema = processor.schemas["customer_records"]
customer_schema = ws_schemas["customer_records"]
assert customer_schema.name == "customer_records"
assert len(customer_schema.fields) == 4
# Verify product schema
product_schema = processor.schemas["product_catalog"]
product_schema = ws_schemas["product_catalog"]
assert product_schema.name == "product_catalog"
assert len(product_schema.fields) == 4
@ -224,12 +239,11 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create realistic customer data chunk
metadata = Metadata(
id="customer-doc-001",
user="integration_test",
collection="test_documents",
)
@ -291,12 +305,11 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create realistic product data chunk
metadata = Metadata(
id="product-doc-001",
user="integration_test",
collection="test_documents",
)
@ -355,7 +368,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create multiple test chunks
chunks_data = [
@ -369,7 +382,6 @@ class TestObjectExtractionServiceIntegration:
for chunk_id, text in chunks_data:
metadata = Metadata(
id=chunk_id,
user="concurrent_test",
collection="test_collection",
)
chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8'))
@ -418,19 +430,21 @@ class TestObjectExtractionServiceIntegration:
"customer_records": integration_config["schema"]["customer_records"]
}
}
await processor.on_schema_config(initial_config, version=1)
assert len(processor.schemas) == 1
assert "customer_records" in processor.schemas
assert "product_catalog" not in processor.schemas
await processor.on_schema_config("default", initial_config, version=1)
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 1
assert "customer_records" in ws_schemas
assert "product_catalog" not in ws_schemas
# Act - Reload with full configuration
await processor.on_schema_config(integration_config, version=2)
await processor.on_schema_config("default", integration_config, version=2)
# Assert
assert len(processor.schemas) == 2
assert "customer_records" in processor.schemas
assert "product_catalog" in processor.schemas
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 2
assert "customer_records" in ws_schemas
assert "product_catalog" in ws_schemas
@pytest.mark.asyncio
async def test_error_resilience_integration(self, integration_config):
@ -461,13 +475,14 @@ class TestObjectExtractionServiceIntegration:
return AsyncMock()
failing_flow.side_effect = failing_context_router
failing_flow.workspace = "default"
processor.flow = failing_flow
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create test chunk
metadata = Metadata(id="error-test", user="test", collection="test")
metadata = Metadata(id="error-test", collection="test")
chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process")
mock_msg = MagicMock()
@ -497,12 +512,11 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create chunk with rich metadata
original_metadata = Metadata(
id="metadata-test-chunk",
user="test_user",
collection="test_collection",
)
@ -531,6 +545,5 @@ class TestObjectExtractionServiceIntegration:
assert extracted_obj is not None
# Verify metadata propagation
assert extracted_obj.metadata.user == "test_user"
assert extracted_obj.metadata.collection == "test_collection"
assert "metadata-test-chunk" in extracted_obj.metadata.id # Should include source reference

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.prompt.template.service import Processor
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
from trustgraph.base.text_completion_client import TextCompletionResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@ -27,34 +28,52 @@ class TestPromptStreaming:
# Mock text completion client with streaming
text_completion_client = AsyncMock()
async def streaming_request(request, recipient=None, timeout=600):
"""Simulate streaming text completion"""
if request.streaming and recipient:
# Simulate streaming chunks
chunks = [
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
# Streaming chunks to send
chunks = [
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
for i, chunk_text in enumerate(chunks):
is_final = (i == len(chunks) - 1)
response = TextCompletionResponse(
response=chunk_text,
error=None,
end_of_stream=is_final
)
final = await recipient(response)
if final:
break
# Final empty chunk
await recipient(TextCompletionResponse(
response="",
async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
"""Simulate streaming text completion via text_completion_stream"""
for i, chunk_text in enumerate(chunks):
response = TextCompletionResponse(
response=chunk_text,
error=None,
end_of_stream=True
))
end_of_stream=False
)
await handler(response)
text_completion_client.request = streaming_request
# Send final empty chunk with end_of_stream
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
return TextCompletionResult(
text=None,
in_token=10,
out_token=9,
model="test-model",
)
async def non_streaming_text_completion(system, prompt, timeout=600):
"""Simulate non-streaming text completion"""
full_text = "Machine learning is a field of artificial intelligence."
return TextCompletionResult(
text=full_text,
in_token=10,
out_token=9,
model="test-model",
)
text_completion_client.text_completion_stream = AsyncMock(
side_effect=streaming_text_completion_stream
)
text_completion_client.text_completion = AsyncMock(
side_effect=non_streaming_text_completion
)
# Mock response producer
response_producer = AsyncMock()
@ -68,6 +87,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
return context
@pytest.fixture
@ -90,7 +110,7 @@ class TestPromptStreaming:
def prompt_processor_streaming(self, mock_prompt_manager):
"""Create Prompt processor with streaming support"""
processor = MagicMock()
processor.manager = mock_prompt_manager
processor.managers = {"default": mock_prompt_manager}
processor.config_key = "prompt"
# Bind the actual on_request method
@ -156,14 +176,6 @@ class TestPromptStreaming:
consumer = MagicMock()
# Mock non-streaming text completion
text_completion_client = mock_flow_context_streaming("text-completion-request")
async def non_streaming_text_completion(system, prompt, streaming=False):
return "AI is the simulation of human intelligence in machines."
text_completion_client.text_completion = non_streaming_text_completion
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
@ -218,17 +230,12 @@ class TestPromptStreaming:
# Mock text completion client that raises an error
text_completion_client = AsyncMock()
async def failing_request(request, recipient=None, timeout=600):
if recipient:
# Send error response with proper Error schema
error_response = TextCompletionResponse(
response="",
error=Error(message="Text completion error", type="processing_error"),
end_of_stream=True
)
await recipient(error_response)
async def failing_stream(system, prompt, handler, timeout=600):
raise RuntimeError("Text completion error")
text_completion_client.request = failing_request
text_completion_client.text_completion_stream = AsyncMock(
side_effect=failing_stream
)
# Mock response producer to capture error response
response_producer = AsyncMock()
@ -242,6 +249,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
request = PromptRequest(
id="test_prompt",
@ -255,22 +263,15 @@ class TestPromptStreaming:
consumer = MagicMock()
# Act - The service catches errors and sends error responses, doesn't raise
# Act - The service catches errors and sends an error PromptResponse
await prompt_processor_streaming.on_request(message, consumer, context)
# Assert - Verify error response was sent
assert response_producer.send.call_count > 0
# Check that at least one response contains an error
error_sent = False
for call in response_producer.send.call_args_list:
response = call.args[0]
if hasattr(response, 'error') and response.error:
error_sent = True
assert "Text completion error" in response.error.message
break
assert error_sent, "Expected error response to be sent"
# Assert - error response was sent
calls = response_producer.send.call_args_list
assert len(calls) > 0
error_response = calls[-1].args[0]
assert error_response.error is not None
assert "Text completion error" in error_response.error.message
@pytest.mark.asyncio
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
@ -315,21 +316,22 @@ class TestPromptStreaming:
# Mock text completion that sends empty chunks
text_completion_client = AsyncMock()
async def empty_streaming_request(request, recipient=None, timeout=600):
if request.streaming and recipient:
# Send empty chunk followed by final marker
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=False
))
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
async def empty_streaming(system, prompt, handler, timeout=600):
# Send empty chunk followed by final marker
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=False
))
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
text_completion_client.request = empty_streaming_request
text_completion_client.text_completion_stream = AsyncMock(
side_effect=empty_streaming
)
response_producer = AsyncMock()
def context_router(service_name):
@ -341,6 +343,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
request = PromptRequest(
id="test_prompt",
@ -401,4 +404,4 @@ class TestPromptStreaming:
# Verify chunks concatenate to expected result
full_text = "".join(chunk_texts)
assert full_text == "Machine learning is a field of artificial intelligence"
assert full_text == "Machine learning is a field of artificial intelligence."

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
from trustgraph.base import PromptResult
class TestGraphRagStreamingProtocol:
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
# Edge selection returns empty (no edges selected)
return ""
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
await chunk_callback(" answer", False)
await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
return "" # Return value not used since callback handles everything
return PromptResult(response_type="text", text="")
else:
return "The answer is here."
return ""
return PromptResult(response_type="text", text="The answer is here.")
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@ -84,7 +84,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -108,7 +107,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -137,7 +135,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -162,7 +159,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -188,7 +184,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -237,11 +232,13 @@ class TestDocumentRagStreamingProtocol:
await chunk_callback("Document", False)
await chunk_callback(" summary", False)
await chunk_callback(".", True) # Non-empty final chunk
return ""
return PromptResult(response_type="text", text="")
else:
return "Document summary."
return PromptResult(response_type="text", text="Document summary.")
client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -265,7 +262,6 @@ class TestDocumentRagStreamingProtocol:
# Act
await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -288,7 +284,6 @@ class TestDocumentRagStreamingProtocol:
# Act
await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -312,7 +307,6 @@ class TestDocumentRagStreamingProtocol:
# Act
await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -334,17 +328,17 @@ class TestStreamingProtocolEdgeCases:
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
return ""
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final
return ""
return PromptResult(response_type="text", text="")
else:
return "textmore"
return ""
return PromptResult(response_type="text", text="textmore")
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_with_empties

View file

@ -14,10 +14,35 @@ from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class _MockFlowDefault:
"""Mock Flow with default workspace for testing."""
workspace = "default"
name = "default"
id = "test-processor"
mock_flow_default = _MockFlowDefault()
@pytest.mark.integration
class TestRowsCassandraIntegration:
"""Integration tests for Cassandra row storage with unified table"""
@pytest.fixture(autouse=True)
def patch_async_execute(self):
"""Route async_execute through session.execute so the mock's
side_effect handles all CQL (DDL and DML) uniformly and every
call lands in mock_session.execute.call_args_list."""
async def _fake(session, query, params=None):
session.execute(query, params)
return []
with patch(
'trustgraph.storage.rows.cassandra.write.async_execute',
new=_fake,
):
yield
@pytest.fixture
def mock_cassandra_session(self):
"""Mock Cassandra session for integration tests"""
@ -111,14 +136,13 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
assert "customer_records" in processor.schemas
await processor.on_schema_config("default", config, version=1)
assert "customer_records" in processor.schemas["default"]
# Step 2: Process an ExtractedObject
test_obj = ExtractedObject(
metadata=Metadata(
id="doc-001",
user="test_user",
collection="import_2024",
),
schema_name="customer_records",
@ -135,7 +159,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify Cassandra interactions
assert mock_cluster.connect.called
@ -144,7 +168,7 @@ class TestRowsCassandraIntegration:
keyspace_calls = [call for call in mock_session.execute.call_args_list
if "CREATE KEYSPACE" in str(call)]
assert len(keyspace_calls) == 1
assert "test_user" in str(keyspace_calls[0])
assert "default" in str(keyspace_calls[0])
# Verify unified table creation (rows table, not per-schema table)
table_calls = [call for call in mock_session.execute.call_args_list
@ -195,12 +219,12 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
assert len(processor.schemas) == 2
await processor.on_schema_config("default", config, version=1)
assert len(processor.schemas["default"]) == 2
# Process objects for different schemas
product_obj = ExtractedObject(
metadata=Metadata(id="p1", user="shop", collection="catalog"),
metadata=Metadata(id="p1", collection="catalog"),
schema_name="products",
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
confidence=0.9,
@ -208,7 +232,7 @@ class TestRowsCassandraIntegration:
)
order_obj = ExtractedObject(
metadata=Metadata(id="o1", user="shop", collection="sales"),
metadata=Metadata(id="o1", collection="sales"),
schema_name="orders",
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
confidence=0.85,
@ -219,7 +243,7 @@ class TestRowsCassandraIntegration:
for obj in [product_obj, order_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# All data goes into the same unified rows table
table_calls = [call for call in mock_session.execute.call_args_list
@ -242,18 +266,20 @@ class TestRowsCassandraIntegration:
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Schema with multiple indexed fields
processor.schemas["indexed_data"] = RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
processor.schemas["default"] = {
"indexed_data": RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test"),
metadata=Metadata(id="t1", collection="test"),
schema_name="indexed_data",
values=[{
"id": "123",
@ -268,7 +294,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 data inserts (one per indexed field: id, category, status)
rows_insert_calls = [call for call in mock_session.execute.call_args_list
@ -328,13 +354,12 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
await processor.on_schema_config("default", config, version=1)
# Process batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_import",
),
schema_name="batch_customers",
@ -362,7 +387,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify unified table creation
table_calls = [call for call in mock_session.execute.call_args_list
@ -382,14 +407,16 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["empty_test"] = RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
processor.schemas["default"] = {
"empty_test": RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
}
# Process empty batch object
empty_obj = ExtractedObject(
metadata=Metadata(id="empty-1", user="test", collection="empty"),
metadata=Metadata(id="empty-1", collection="empty"),
schema_name="empty_test",
values=[], # Empty batch
confidence=1.0,
@ -399,7 +426,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should not create any data insert statements for empty batch
# (partition registration may still happen)
@ -414,17 +441,19 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["map_test"] = RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
processor.schemas["default"] = {
"map_test": RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test"),
metadata=Metadata(id="t1", collection="test"),
schema_name="map_test",
values=[{"id": "123", "name": "Test Item", "count": "42"}],
confidence=0.9,
@ -434,7 +463,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify insert uses map for data
rows_insert_calls = [call for call in mock_session.execute.call_args_list
@ -459,16 +488,18 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["partition_test"] = RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
processor.schemas["default"] = {
"partition_test": RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="my_collection"),
metadata=Metadata(id="t1", collection="my_collection"),
schema_name="partition_test",
values=[{"id": "123", "category": "test"}],
confidence=0.9,
@ -478,7 +509,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify partition registration
partition_inserts = [call for call in mock_session.execute.call_args_list

View file

@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
"""Test schema configuration loading and GraphQL schema generation"""
# Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Verify schemas were loaded
assert len(processor.schemas) == 2
@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
"""Test Cassandra connection and dynamic table creation"""
# Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Connect to Cassandra
processor.connect_cassandra()
@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
"""Test inserting data and querying via GraphQL"""
# Load schema and connect
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Setup test data
@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_query_with_filters(self, processor, sample_schema_config):
"""Test GraphQL queries with filtering on indexed fields"""
# Setup (reuse previous setup)
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
keyspace = "test_user"
@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_error_handling(self, processor, sample_schema_config):
"""Test GraphQL error handling for invalid queries"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Test invalid field query
invalid_query = '''
@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_message_processing_integration(self, processor, sample_schema_config):
"""Test full message processing workflow"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Create mock message
@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_concurrent_queries(self, processor, sample_schema_config):
"""Test handling multiple concurrent GraphQL queries"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Create multiple query tasks
@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration:
}
}
await processor.on_schema_config(initial_config, version=1)
await processor.on_schema_config("default", initial_config, version=1)
assert len(processor.schemas) == 1
assert "simple" in processor.schemas
@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration:
}
}
await processor.on_schema_config(updated_config, version=2)
await processor.on_schema_config("default", updated_config, version=2)
# Verify updated schemas
assert len(processor.schemas) == 2
@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_large_result_set_handling(self, processor, sample_schema_config):
"""Test handling of large query result sets"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
keyspace = "large_test_user"
@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance:
}
}
await processor.on_schema_config(schema_config, version=1)
await processor.on_schema_config("default", schema_config, version=1)
# Measure query execution time
start_time = time.time()

View file

@ -42,7 +42,6 @@ class TestStructuredQueryServiceIntegration:
# Arrange - Create realistic query request
request = StructuredQueryRequest(
question="Show me all customers from California who have made purchases over $500",
user="trustgraph",
collection="default"
)
@ -126,7 +125,6 @@ class TestStructuredQueryServiceIntegration:
assert "orders" in objects_call_args.query
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
assert objects_call_args.variables["state"] == "California"
assert objects_call_args.user == "trustgraph"
assert objects_call_args.collection == "default"
# Verify response

View file

@ -1,17 +1,13 @@
[pytest]
testpaths = tests
python_paths = .
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
addopts =
-v
--tb=short
--strict-markers
--disable-warnings
--cov=trustgraph
--cov-report=html
--cov-report=term-missing
# --disable-warnings
# --cov-fail-under=80
asyncio_mode = auto
markers =

View file

@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to call think and observe callbacks
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming:
msg = MagicMock()
msg.value.return_value = AgentRequest(
question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()
@ -78,10 +81,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
r for r in sent_responses if r.message_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
r for r in sent_responses if r.message_type == "explain"
]
# Should have explain events for session, iteration, observation, and final
@ -93,7 +96,7 @@ class TestAgentServiceNonStreaming:
# Check thought message
thought_response = content_responses[0]
assert isinstance(thought_response, AgentResponse)
assert thought_response.chunk_type == "thought"
assert thought_response.message_type == "thought"
assert thought_response.content == "I need to solve this."
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
@ -101,7 +104,7 @@ class TestAgentServiceNonStreaming:
# Check observation message
observation_response = content_responses[1]
assert isinstance(observation_response, AgentResponse)
assert observation_response.chunk_type == "observation"
assert observation_response.message_type == "observation"
assert observation_response.content == "The answer is 4."
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to return Final directly
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming:
msg = MagicMock()
msg.value.return_value = AgentRequest(
question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()
@ -168,10 +174,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
r for r in sent_responses if r.message_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
r for r in sent_responses if r.message_type == "explain"
]
# Should have explain events for session and final
@ -183,7 +189,7 @@ class TestAgentServiceNonStreaming:
# Check final answer message
answer_response = content_responses[0]
assert isinstance(answer_response, AgentResponse)
assert answer_response.chunk_type == "answer"
assert answer_response.message_type == "answer"
assert answer_response.content == "4"
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"

View file

@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(question="Test question", user="testuser",
def _make_request(question="Test question",
collection="default", streaming=False,
session_id="parent-session", task_type="research",
framing="test framing", conversation_id="conv-1"):
return AgentRequest(
question=question,
user=user,
collection=collection,
streaming=streaming,
session_id=session_id,
@ -127,7 +126,6 @@ class TestBuildSynthesisRequest:
req = agg.build_synthesis_request(
"corr-1",
original_question="Original question",
user="testuser",
collection="default",
)
@ -148,7 +146,7 @@ class TestBuildSynthesisRequest:
agg.record_completion("corr-1", "goal-b", "answer-b")
req = agg.build_synthesis_request(
"corr-1", "question", "user", "default",
"corr-1", "question", "default",
)
# Last history step should be the synthesis step
@ -168,7 +166,7 @@ class TestBuildSynthesisRequest:
agg.record_completion("corr-1", "goal-a", "answer-a")
agg.build_synthesis_request(
"corr-1", "question", "user", "default",
"corr-1", "question", "default",
)
# Entry should be removed
@ -178,7 +176,7 @@ class TestBuildSynthesisRequest:
agg = Aggregator()
with pytest.raises(RuntimeError, match="No results"):
agg.build_synthesis_request(
"unknown", "question", "user", "default",
"unknown", "question", "default",
)

View file

@ -29,7 +29,7 @@ class TestThinkCallbackMessageId:
assert len(responses) == 1
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "thought"
assert responses[0].message_type == "thought"
@pytest.mark.asyncio
async def test_non_streaming_think_has_message_id(self, pattern):
@ -58,7 +58,7 @@ class TestObserveCallbackMessageId:
await observe("result", is_final=True)
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "observation"
assert responses[0].message_type == "observation"
class TestAnswerCallbackMessageId:
@ -74,7 +74,7 @@ class TestAnswerCallbackMessageId:
await answer("the answer")
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "answer"
assert responses[0].message_type == "answer"
@pytest.mark.asyncio
async def test_no_message_id_default(self, pattern):

View file

@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(**kwargs):
defaults = dict(
question="Test question",
user="testuser",
collection="default",
)
defaults.update(kwargs)
@ -130,7 +129,6 @@ class TestAggregatorIntegration:
synth = agg.build_synthesis_request(
"corr-1",
original_question="Original question",
user="testuser",
collection="default",
)
@ -160,7 +158,7 @@ class TestAggregatorIntegration:
agg.record_completion("corr-1", "goal", "answer")
synth = agg.build_synthesis_request(
"corr-1", "question", "user", "default",
"corr-1", "question", "default",
)
# correlation_id must be empty so it's not intercepted

View file

@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
)
from trustgraph.base import PromptResult
def _make_config(patterns=None, task_types=None):
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client."""
client = AsyncMock()
client.prompt = AsyncMock(return_value=prompt_response)
client.prompt = AsyncMock(
return_value=PromptResult(response_type="text", text=prompt_response)
)
def context(service_name):
return client
@ -274,8 +277,8 @@ class TestRoute:
nonlocal call_count
call_count += 1
if call_count == 1:
return "research" # task type
return "plan-then-execute" # pattern
return PromptResult(response_type="text", text="research")
return PromptResult(response_type="text", text="plan-then-execute")
client.prompt = mock_prompt
context = lambda name: client

View file

@ -18,6 +18,7 @@ from dataclasses import dataclass, field
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
)
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -68,7 +69,7 @@ def collect_explain_events(respond_mock):
events = []
for call in respond_mock.call_args_list:
resp = call[0][0]
if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
events.append({
"explain_id": resp.explain_id,
"explain_graph": resp.explain_graph,
@ -125,7 +126,6 @@ def make_base_request(**kwargs):
state="",
group=[],
history=[],
user="testuser",
collection="default",
streaming=False,
session_id="test-session-123",
@ -183,7 +183,7 @@ class TestReactPatternProvenance:
)
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
# Simulate the on_action callback before returning Final
if on_action:
await on_action(Action(
@ -267,7 +267,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(action)
return action
@ -309,7 +309,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(Action(
thought="done", name="final",
@ -355,10 +355,13 @@ class TestPlanPatternProvenance:
# Mock prompt client for plan creation
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
],
)
def flow_factory(name):
if name == "prompt-request":
@ -418,10 +421,13 @@ class TestPlanPatternProvenance:
# Mock prompt for step execution
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = {
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
}
mock_prompt_client.prompt.return_value = PromptResult(
response_type="json",
object={
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
},
)
def flow_factory(name):
if name == "prompt-request":
@ -475,7 +481,7 @@ class TestPlanPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The synthesised answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
def flow_factory(name):
if name == "prompt-request":
@ -542,10 +548,13 @@ class TestSupervisorPatternProvenance:
# Mock prompt for decomposition
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
"What is quantum computing?",
"What are qubits?",
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
"What is quantum computing?",
"What are qubits?",
],
)
def flow_factory(name):
if name == "prompt-request":
@ -590,7 +599,7 @@ class TestSupervisorPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The combined answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
def flow_factory(name):
if name == "prompt-request":
@ -639,7 +648,10 @@ class TestSupervisorPatternProvenance:
flow = make_mock_flow()
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=["Goal A", "Goal B", "Goal C"],
)
def flow_factory(name):
if name == "prompt-request":

View file

@ -20,7 +20,7 @@ class TestParseChunkMessageId:
def test_thought_message_id(self, client):
resp = {
"chunk_type": "thought",
"message_type": "thought",
"content": "thinking...",
"end_of_message": False,
"message_id": "urn:trustgraph:agent:sess/i1/thought",
@ -31,7 +31,7 @@ class TestParseChunkMessageId:
def test_observation_message_id(self, client):
resp = {
"chunk_type": "observation",
"message_type": "observation",
"content": "result",
"end_of_message": True,
"message_id": "urn:trustgraph:agent:sess/i1/observation",
@ -42,7 +42,7 @@ class TestParseChunkMessageId:
def test_answer_message_id(self, client):
resp = {
"chunk_type": "answer",
"message_type": "answer",
"content": "the answer",
"end_of_message": False,
"end_of_dialog": False,
@ -54,7 +54,7 @@ class TestParseChunkMessageId:
def test_thought_missing_message_id(self, client):
resp = {
"chunk_type": "thought",
"message_type": "thought",
"content": "thinking...",
"end_of_message": False,
}
@ -64,7 +64,7 @@ class TestParseChunkMessageId:
def test_answer_missing_message_id(self, client):
resp = {
"chunk_type": "answer",
"message_type": "answer",
"content": "answer",
"end_of_message": True,
"end_of_dialog": True,

View file

@ -21,7 +21,6 @@ class MockProcessor:
def _make_request(**kwargs):
defaults = dict(
question="Test question",
user="testuser",
collection="default",
)
defaults.update(kwargs)

View file

@ -9,6 +9,7 @@ tool usage patterns.
import pytest
from unittest.mock import Mock, AsyncMock
import asyncio
import inspect
from collections import defaultdict
@ -133,7 +134,7 @@ class TestToolCoordinationLogic:
resolved_params[key] = value
# Execute tool
if asyncio.iscoroutinefunction(tool_function):
if inspect.iscoroutinefunction(tool_function):
result = await tool_function(**resolved_params)
else:
result = tool_function(**resolved_params)
@ -227,7 +228,7 @@ class TestToolCoordinationLogic:
# Simulate async execution with delay
await asyncio.sleep(0.001) # Small delay to simulate work
if asyncio.iscoroutinefunction(tool_function):
if inspect.iscoroutinefunction(tool_function):
result = await tool_function(**parameters)
else:
result = tool_function(**parameters)
@ -337,7 +338,7 @@ class TestToolCoordinationLogic:
if attempt > 0:
await asyncio.sleep(0.001 * (self.backoff_factor ** attempt))
if asyncio.iscoroutinefunction(tool_function):
if inspect.iscoroutinefunction(tool_function):
result = await tool_function(**parameters)
else:
result = tool_function(**parameters)

View file

@ -167,39 +167,28 @@ class TestToolServiceRequest:
"""Test cases for tool service request format"""
def test_request_format(self):
"""Test that request is properly formatted with user, config, and arguments"""
# Arrange
user = "alice"
"""Test that request is properly formatted with config and arguments"""
config_values = {"style": "pun", "collection": "jokes"}
arguments = {"topic": "programming"}
# Act - simulate request building
request = {
"user": user,
"config": json.dumps(config_values),
"arguments": json.dumps(arguments)
}
# Assert
assert request["user"] == "alice"
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
assert json.loads(request["arguments"]) == {"topic": "programming"}
def test_request_with_empty_config(self):
"""Test request when no config values are provided"""
# Arrange
user = "bob"
config_values = {}
arguments = {"query": "test"}
# Act
request = {
"user": user,
"config": json.dumps(config_values) if config_values else "{}",
"arguments": json.dumps(arguments) if arguments else "{}"
}
# Assert
assert request["config"] == "{}"
assert json.loads(request["arguments"]) == {"query": "test"}
@ -386,18 +375,13 @@ class TestJokeServiceLogic:
assert map_topic_to_category("random topic") == "default"
assert map_topic_to_category("") == "default"
def test_joke_response_personalization(self):
"""Test that joke responses include user personalization"""
# Arrange
user = "alice"
def test_joke_response_format(self):
"""Test that joke response is formatted as expected"""
style = "pun"
joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
# Act
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
response = f"Here's a {style} for you:\n\n{joke}"
# Assert
assert "Hey alice!" in response
assert "pun" in response
assert joke in response
@ -439,20 +423,14 @@ class TestDynamicToolServiceBase:
def test_request_parsing(self):
"""Test parsing of incoming request"""
# Arrange
request_data = {
"user": "alice",
"config": '{"style": "pun"}',
"arguments": '{"topic": "programming"}'
}
# Act
user = request_data.get("user", "trustgraph")
config = json.loads(request_data["config"]) if request_data["config"] else {}
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
# Assert
assert user == "alice"
assert config == {"style": "pun"}
assert arguments == {"topic": "programming"}

View file

@ -1,6 +1,6 @@
"""
Tests for tool service lifecycle, invoke contract, streaming responses,
multi-tenancy, and error propagation.
and error propagation.
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
classes rather than plain dicts.
@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract:
svc = DynamicToolService.__new__(DynamicToolService)
with pytest.raises(NotImplementedError):
await svc.invoke("user", {}, {})
await svc.invoke({}, {})
@pytest.mark.asyncio
async def test_on_request_calls_invoke_with_parsed_args(self):
@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract:
calls = []
async def tracking_invoke(user, config, arguments):
calls.append({"user": user, "config": config, "arguments": arguments})
async def tracking_invoke(config, arguments):
calls.append({"config": config, "arguments": arguments})
return "ok"
svc.invoke = tracking_invoke
@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract:
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user="alice",
config='{"style": "pun"}',
arguments='{"topic": "cats"}',
)
@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract:
await svc.on_request(msg, MagicMock(), None)
assert len(calls) == 1
assert calls[0]["user"] == "alice"
assert calls[0]["config"] == {"style": "pun"}
assert calls[0]["arguments"] == {"topic": "cats"}
@pytest.mark.asyncio
async def test_on_request_empty_user_defaults_to_trustgraph(self):
"""Empty user field should default to 'trustgraph'."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
received_user = None
async def capture_invoke(user, config, arguments):
nonlocal received_user
received_user = user
return "ok"
svc.invoke = capture_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
msg.properties.return_value = {"id": "req-2"}
await svc.on_request(msg, MagicMock(), None)
assert received_user == "trustgraph"
@pytest.mark.asyncio
async def test_on_request_string_response_sent_directly(self):
"""String return from invoke → response field is the string."""
@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def string_invoke(user, config, arguments):
async def string_invoke(config, arguments):
return "hello world"
svc.invoke = string_invoke
@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r1"}
await svc.on_request(msg, MagicMock(), None)
@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def dict_invoke(user, config, arguments):
async def dict_invoke(config, arguments):
return {"result": 42}
svc.invoke = dict_invoke
@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r2"}
await svc.on_request(msg, MagicMock(), None)
@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def failing_invoke(user, config, arguments):
async def failing_invoke(config, arguments):
raise ValueError("bad input")
svc.invoke = failing_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r3"}
await svc.on_request(msg, MagicMock(), None)
@ -188,13 +157,13 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def rate_limited_invoke(user, config, arguments):
async def rate_limited_invoke(config, arguments):
raise TooManyRequests("rate limited")
svc.invoke = rate_limited_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r4"}
with pytest.raises(TooManyRequests):
@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def ok_invoke(user, config, arguments):
async def ok_invoke(config, arguments):
return "ok"
svc.invoke = ok_invoke
@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "unique-42"}
await svc.on_request(msg, MagicMock(), None)
@ -241,7 +210,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return "tool result"
svc.invoke_tool = mock_invoke
@ -260,6 +229,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
@ -280,7 +250,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return {"data": [1, 2, 3]}
svc.invoke_tool = mock_invoke
@ -298,6 +268,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -317,7 +288,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def failing_invoke(name, params):
async def failing_invoke(workspace, name, params):
raise RuntimeError("tool broke")
svc.invoke_tool = failing_invoke
@ -330,6 +301,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -350,7 +322,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def rate_limited(name, params):
async def rate_limited(workspace, name, params):
raise TooManyRequests("slow down")
svc.invoke_tool = rate_limited
@ -362,6 +334,7 @@ class TestToolServiceOnRequest:
flow = MagicMock()
flow.producer = {"response": AsyncMock()}
flow.name = "test-flow"
flow.workspace = "default"
with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), flow)
@ -376,7 +349,8 @@ class TestToolServiceOnRequest:
received = {}
async def capture_invoke(name, params):
async def capture_invoke(workspace, name, params):
received["workspace"] = workspace
received["name"] = name
received["params"] = params
return "ok"
@ -390,6 +364,7 @@ class TestToolServiceOnRequest:
flow = lambda name: mock_pub
flow.producer = {"response": mock_pub}
flow.name = "f"
flow.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(
@ -421,7 +396,6 @@ class TestToolServiceClientCall:
))
result = await client.call(
user="alice",
config={"style": "pun"},
arguments={"topic": "cats"},
)
@ -430,7 +404,6 @@ class TestToolServiceClientCall:
req = client.request.call_args[0][0]
assert isinstance(req, ToolServiceRequest)
assert req.user == "alice"
assert json.loads(req.config) == {"style": "pun"}
assert json.loads(req.arguments) == {"topic": "cats"}
@ -446,7 +419,7 @@ class TestToolServiceClientCall:
))
with pytest.raises(RuntimeError, match="service down"):
await client.call(user="u", config={}, arguments={})
await client.call(config={}, arguments={})
@pytest.mark.asyncio
async def test_call_empty_config_sends_empty_json(self):
@ -458,7 +431,7 @@ class TestToolServiceClientCall:
error=None, response="ok",
))
await client.call(user="u", config=None, arguments=None)
await client.call(config=None, arguments=None)
req = client.request.call_args[0][0]
assert req.config == "{}"
@ -474,7 +447,7 @@ class TestToolServiceClientCall:
error=None, response="ok",
))
await client.call(user="u", config={}, arguments={}, timeout=30)
await client.call(config={}, arguments={}, timeout=30)
_, kwargs = client.request.call_args
assert kwargs["timeout"] == 30
@ -509,7 +482,7 @@ class TestToolServiceClientStreaming:
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
config={}, arguments={}, callback=callback,
)
assert result == "chunk1chunk2"
@ -534,7 +507,7 @@ class TestToolServiceClientStreaming:
with pytest.raises(RuntimeError, match="stream failed"):
await client.call_streaming(
user="u", config={}, arguments={},
config={}, arguments={},
callback=AsyncMock(),
)
@ -564,61 +537,9 @@ class TestToolServiceClientStreaming:
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
config={}, arguments={}, callback=callback,
)
# Empty response is falsy, so callback shouldn't be called for it
assert result == "data"
assert received == ["data"]
# ---------------------------------------------------------------------------
# Multi-tenancy
# ---------------------------------------------------------------------------
class TestMultiTenancy:
@pytest.mark.asyncio
async def test_user_propagated_to_invoke(self):
"""User from request should reach the invoke method."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test"
svc.producer = AsyncMock()
users_seen = []
async def tracking(user, config, arguments):
users_seen.append(user)
return "ok"
svc.invoke = tracking
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
for u in ["tenant-a", "tenant-b", "tenant-c"]:
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user=u, config="{}", arguments="{}",
)
msg.properties.return_value = {"id": f"req-{u}"}
await svc.on_request(msg, MagicMock(), None)
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
@pytest.mark.asyncio
async def test_client_sends_user_in_request(self):
"""ToolServiceClient.call should include user in request."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="ok",
))
await client.call(user="isolated-tenant", config={}, arguments={})
req = client.request.call_args[0][0]
assert req.user == "isolated-tenant"

View file

@ -1,17 +1,14 @@
"""
Tests for AsyncProcessor config notify pattern:
- register_config_handler with types filtering
- on_config_notify version comparison and type matching
- fetch_config with short-lived client
- fetch_and_apply_config retry logic
- on_config_notify version comparison, type/workspace matching
- fetch_and_apply_config retry logic over per-workspace fetches
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, Mock
from trustgraph.schema import Term, IRI, LITERAL
# Patch heavy dependencies before importing AsyncProcessor
@pytest.fixture
def processor():
"""Create an AsyncProcessor with mocked dependencies."""
@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
assert len(processor.config_handlers) == 2
def _notify_msg(version, changes):
"""Build a Mock config-notify message with given version and changes dict."""
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestOnConfigNotify:
@pytest.mark.asyncio
@ -77,9 +81,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=3, types=["prompt"])
msg = _notify_msg(3, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -91,9 +93,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=5, types=["prompt"])
msg = _notify_msg(5, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -105,9 +105,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=2, types=["schema"])
msg = _notify_msg(2, {"schema": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -121,40 +119,36 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
# Mock fetch_config
mock_config = {"prompt": {"key": "value"}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={"key": "value"},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_called_once_with(
"default", {"prompt": {"key": "value"}}, 2
)
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_handler_without_types_always_called(self, processor):
async def test_handler_without_types_ignored_on_notify(self, processor):
"""Handlers registered without types never fire on notifications."""
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler) # No types = all
processor.register_config_handler(handler) # No types
mock_config = {"anything": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["whatever"])
msg = _notify_msg(2, {"whatever": ["default"]})
await processor.on_config_notify(msg, None, None)
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_not_called()
# Version still advances past the notify
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_mixed_handlers_type_filtering(self, processor):
@ -168,156 +162,149 @@ class TestOnConfigNotify:
processor.register_config_handler(schema_handler, types=["schema"])
processor.register_config_handler(all_handler)
mock_config = {"prompt": {}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
prompt_handler.assert_called_once()
prompt_handler.assert_called_once_with(
"default", {"prompt": {}}, 2
)
schema_handler.assert_not_called()
all_handler.assert_called_once()
all_handler.assert_not_called()
@pytest.mark.asyncio
async def test_empty_types_invokes_all(self, processor):
"""Empty types list (startup signal) should invoke all handlers."""
async def test_multi_workspace_notify_invokes_handler_per_ws(
self, processor
):
"""Notify affecting multiple workspaces invokes handler once per workspace."""
processor.config_version = 1
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
mock_config = {}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=[])
msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
await processor.on_config_notify(msg, None, None)
h1.assert_called_once()
h2.assert_called_once()
assert handler.call_count == 2
called_workspaces = {c.args[0] for c in handler.call_args_list}
assert called_workspaces == {"ws1", "ws2"}
@pytest.mark.asyncio
async def test_fetch_failure_handled(self, processor):
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler)
processor.register_config_handler(handler, types=["prompt"])
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
side_effect=RuntimeError("Connection failed")
side_effect=RuntimeError("Connection failed"),
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
# Should not raise
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
class TestFetchConfig:
@pytest.mark.asyncio
async def test_fetch_returns_config_and_version(self, processor):
mock_resp = Mock()
mock_resp.error = None
mock_resp.config = {"prompt": {"key": "val"}}
mock_resp.version = 42
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
config, version = await processor.fetch_config()
assert config == {"prompt": {"key": "val"}}
assert version == 42
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_raises_on_error_response(self, processor):
mock_resp = Mock()
mock_resp.error = Mock(message="not found")
mock_resp.config = {}
mock_resp.version = 0
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(RuntimeError, match="Config error"):
await processor.fetch_config()
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_stops_client_on_exception(self, processor):
mock_client = AsyncMock()
mock_client.request.side_effect = TimeoutError("timeout")
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(TimeoutError):
await processor.fetch_config()
mock_client.stop.assert_called_once()
class TestFetchAndApplyConfig:
@pytest.mark.asyncio
async def test_applies_config_to_all_handlers(self, processor):
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
async def test_applies_config_per_workspace(self, processor):
"""Startup fetch invokes handler once per workspace affected."""
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {
"ws1": {"k": "v1"},
"ws2": {"k": "v2"},
}, 10
mock_config = {"prompt": {}, "schema": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 10)
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
# On startup, all handlers are invoked regardless of type
h1.assert_called_once_with(mock_config, 10)
h2.assert_called_once_with(mock_config, 10)
assert h.call_count == 2
call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
assert call_map["ws1"] == {"prompt": {"k": "v1"}}
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
assert processor.config_version == 10
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
call_count = 0
mock_config = {"prompt": {}}
async def test_handler_without_types_skipped_at_startup(self, processor):
"""Handlers registered without types fetch nothing at startup."""
typed = AsyncMock()
untyped = AsyncMock()
processor.register_config_handler(typed, types=["prompt"])
processor.register_config_handler(untyped)
async def mock_fetch():
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {"default": {}}, 1
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
typed.assert_called_once()
untyped.assert_not_called()
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
call_count = 0
async def fake_fetch_all(client, config_type):
nonlocal call_count
call_count += 1
if call_count < 3:
raise RuntimeError("not ready")
return mock_config, 5
return {"default": {"k": "v"}}, 5
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
patch('asyncio.sleep', new_callable=AsyncMock):
mock_client = AsyncMock()
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
), patch('asyncio.sleep', new_callable=AsyncMock):
await processor.fetch_and_apply_config()
assert call_count == 3
assert processor.config_version == 5
h.assert_called_once_with(
"default", {"prompt": {"k": "v"}}, 5
)

View file

@ -33,7 +33,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
result = await client.query(
vector=vector,
limit=10,
user="test_user",
collection="test_collection",
timeout=30
)
@ -45,7 +44,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
assert isinstance(call_args, DocumentEmbeddingsRequest)
assert call_args.vector == vector
assert call_args.limit == 10
assert call_args.user == "test_user"
assert call_args.collection == "test_collection"
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
@ -104,7 +102,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client.request.assert_called_once()
call_args = client.request.call_args[0][0]
assert call_args.limit == 20 # Default limit
assert call_args.user == "trustgraph" # Default user
assert call_args.collection == "default" # Default collection
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')

View file

@ -0,0 +1,81 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
from trustgraph.base.flow import Flow
from trustgraph.base.parameter_spec import Parameter, ParameterSpec
from trustgraph.base.spec import Spec
def test_parameter_spec_is_a_spec_and_adds_parameter_value():
spec = ParameterSpec("temperature")
flow = MagicMock(parameter={})
processor = MagicMock()
spec.add(flow, processor, {"parameters": {"temperature": 0.7}})
assert isinstance(spec, Spec)
assert "temperature" in flow.parameter
assert isinstance(flow.parameter["temperature"], Parameter)
assert flow.parameter["temperature"].value == 0.7
def test_parameter_spec_defaults_missing_values_to_none():
spec = ParameterSpec("model")
flow = MagicMock(parameter={})
spec.add(flow, MagicMock(), {})
assert flow.parameter["model"].value is None
def test_parameter_start_and_stop_are_awaitable():
parameter = Parameter("value")
assert asyncio.run(parameter.start()) is None
assert asyncio.run(parameter.stop()) is None
def test_flow_initialization_calls_registered_specs():
spec_one = MagicMock()
spec_two = MagicMock()
processor = MagicMock(specifications=[spec_one, spec_two])
flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
assert flow.id == "processor-1"
assert flow.name == "flow-a"
assert flow.workspace == "default"
assert flow.producer == {}
assert flow.consumer == {}
assert flow.parameter == {}
spec_one.add.assert_called_once_with(flow, processor, {"answer": 42})
spec_two.add.assert_called_once_with(flow, processor, {"answer": 42})
def test_flow_start_and_stop_visit_all_consumers():
consumer_one = AsyncMock()
consumer_two = AsyncMock()
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.consumer = {"one": consumer_one, "two": consumer_two}
asyncio.run(flow.start())
asyncio.run(flow.stop())
consumer_one.start.assert_called_once_with()
consumer_two.start.assert_called_once_with()
consumer_one.stop.assert_called_once_with()
consumer_two.stop.assert_called_once_with()
def test_flow_call_returns_values_in_priority_order():
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.producer["shared"] = "producer-value"
flow.consumer["consumer-only"] = "consumer-value"
flow.consumer["shared"] = "consumer-value"
flow.parameter["parameter-only"] = Parameter("parameter-value")
flow.parameter["shared"] = Parameter("parameter-value")
assert flow("shared") == "producer-value"
assert flow("consumer-only") == "consumer-value"
assert flow("parameter-only") == "parameter-value"
assert flow("missing") is None

View file

@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
# Assert - Flow should be created with access to processor specifications
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
# The flow should have access to the processor's specifications
# (The exact mechanism depends on Flow implementation)

View file

@ -1,58 +1,50 @@
"""
Unit tests for trustgraph.base.flow_processor
Starting small with a single test to verify basic functionality
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.base.flow_processor import FlowProcessor
# Patches needed to let AsyncProcessor.__init__ run without real
# infrastructure while still setting self.id correctly.
ASYNC_PROCESSOR_PATCHES = [
patch('trustgraph.base.async_processor.get_pubsub', return_value=MagicMock()),
patch('trustgraph.base.async_processor.ProcessorMetrics', return_value=MagicMock()),
patch('trustgraph.base.async_processor.Consumer', return_value=MagicMock()),
]
def with_async_processor_patches(func):
"""Apply all AsyncProcessor dependency patches to a test."""
for p in reversed(ASYNC_PROCESSOR_PATCHES):
func = p(func)
return func
class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
"""Test FlowProcessor base class functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_flow_processor_initialization_basic(self, mock_register_config, mock_async_init):
@with_async_processor_patches
async def test_flow_processor_initialization_basic(self, *mocks):
"""Test basic FlowProcessor initialization"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'taskgroup': AsyncMock()
}
# Act
processor = FlowProcessor(**config)
# Assert
# Verify AsyncProcessor.__init__ was called
mock_async_init.assert_called_once()
# Verify register_config_handler was called with the correct handler
mock_register_config.assert_called_once_with(
processor.on_configure_flows, types=["active-flow"]
)
# Verify FlowProcessor-specific initialization
assert hasattr(processor, 'flows')
assert processor.id == 'test-flow-processor'
assert processor.flows == {}
assert hasattr(processor, 'specifications')
assert processor.specifications == []
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_register_specification(self, mock_register_config, mock_async_init):
@with_async_processor_patches
async def test_register_specification(self, *mocks):
"""Test registering a specification"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'taskgroup': AsyncMock()
@ -62,288 +54,210 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
mock_spec = MagicMock()
mock_spec.name = 'test-spec'
# Act
processor.register_specification(mock_spec)
# Assert
assert len(processor.specifications) == 1
assert processor.specifications[0] == mock_spec
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_start_flow(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_start_flow(self, *mocks):
"""Test starting a flow"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_flow_class = mocks[-1]
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor' # Set id for Flow creation
mock_flow = AsyncMock()
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
# Assert
assert flow_name in processor.flows
# Verify Flow was created with correct parameters
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
# Verify the flow's start method was called
assert ("default", flow_name) in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', flow_name, "default", processor, flow_defn
)
mock_flow.start.assert_called_once()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_stop_flow(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_stop_flow(self, *mocks):
"""Test stopping a flow"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_flow_class = mocks[-1]
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
mock_flow = AsyncMock()
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
await processor.start_flow("default", flow_name, {'config': 'test-config'})
# Start a flow first
await processor.start_flow(flow_name, flow_defn)
# Act
await processor.stop_flow(flow_name)
await processor.stop_flow("default", flow_name)
# Assert
assert flow_name not in processor.flows
assert ("default", flow_name) not in processor.flows
mock_flow.stop.assert_called_once()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_stop_flow_not_exists(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_stop_flow_not_exists(self, *mocks):
"""Test stopping a flow that doesn't exist"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Act - should not raise an exception
await processor.stop_flow('non-existent-flow')
# Assert - flows dict should still be empty
await processor.stop_flow("default", 'non-existent-flow')
assert processor.flows == {}
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_on_configure_flows_basic(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_on_configure_flows_basic(self, *mocks):
"""Test basic flow configuration handling"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_flow_class = mocks[-1]
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
mock_flow = AsyncMock()
mock_flow_class.return_value = mock_flow
# Configuration with flows for this processor
flow_config = {
'test-flow': {'config': 'test-config'}
}
config_data = {
'active-flow': {
'test-processor': '{"test-flow": {"config": "test-config"}}'
'processor:test-processor': {
'test-flow': '{"config": "test-config"}'
}
}
# Act
await processor.on_configure_flows(config_data, version=1)
# Assert
assert 'test-flow' in processor.flows
mock_flow_class.assert_called_once_with('test-processor', 'test-flow', processor, {'config': 'test-config'})
await processor.on_configure_flows("default", config_data, version=1)
assert ("default", 'test-flow') in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', 'test-flow', "default", processor,
{'config': 'test-config'}
)
mock_flow.start.assert_called_once()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_on_configure_flows_no_config(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_on_configure_flows_no_config(self, *mocks):
"""Test flow configuration handling when no config exists for this processor"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
# Configuration without flows for this processor
config_data = {
'active-flow': {
'other-processor': '{"other-flow": {"config": "other-config"}}'
'processor:other-processor': {
'other-flow': '{"config": "other-config"}'
}
}
# Act
await processor.on_configure_flows(config_data, version=1)
# Assert
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
mock_flow_class.assert_not_called()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_on_configure_flows_invalid_config(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_on_configure_flows_invalid_config(self, *mocks):
"""Test flow configuration handling with invalid config format"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
# Configuration without active-flow key
config_data = {
'other-data': 'some-value'
}
# Act
await processor.on_configure_flows(config_data, version=1)
# Assert
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
mock_flow_class.assert_not_called()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_on_configure_flows_start_and_stop(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_on_configure_flows_start_and_stop(self, *mocks):
"""Test flow configuration handling with starting and stopping flows"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_flow_class = mocks[-1]
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
mock_flow1 = AsyncMock()
mock_flow2 = AsyncMock()
mock_flow_class.side_effect = [mock_flow1, mock_flow2]
# First configuration - start flow1
config_data1 = {
'active-flow': {
'test-processor': '{"flow1": {"config": "config1"}}'
'processor:test-processor': {
'flow1': '{"config": "config1"}'
}
}
await processor.on_configure_flows(config_data1, version=1)
await processor.on_configure_flows("default", config_data1, version=1)
# Second configuration - stop flow1, start flow2
config_data2 = {
'active-flow': {
'test-processor': '{"flow2": {"config": "config2"}}'
'processor:test-processor': {
'flow2': '{"config": "config2"}'
}
}
# Act
await processor.on_configure_flows(config_data2, version=2)
# Assert
# flow1 should be stopped and removed
assert 'flow1' not in processor.flows
await processor.on_configure_flows("default", config_data2, version=2)
assert ("default", 'flow1') not in processor.flows
mock_flow1.stop.assert_called_once()
# flow2 should be started and added
assert 'flow2' in processor.flows
assert ("default", 'flow2') in processor.flows
mock_flow2.start.assert_called_once()
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
@with_async_processor_patches
@patch('trustgraph.base.async_processor.AsyncProcessor.start')
async def test_start_calls_parent(self, mock_parent_start, mock_register_config, mock_async_init):
async def test_start_calls_parent(self, mock_parent_start, *mocks):
"""Test that start() calls parent start method"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_parent_start.return_value = None
config = {
'id': 'test-flow-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Act
await processor.start()
# Assert
mock_parent_start.assert_called_once()
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_add_args_calls_parent(self, mock_register_config, mock_async_init):
async def test_add_args_calls_parent(self):
"""Test that add_args() calls parent add_args method"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_parser = MagicMock()
# Act
with patch('trustgraph.base.async_processor.AsyncProcessor.add_args') as mock_parent_add_args:
FlowProcessor.add_args(mock_parser)
# Assert
mock_parent_add_args.assert_called_once_with(mock_parser)
if __name__ == '__main__':
pytest.main([__file__])
pytest.main([__file__])

View file

@ -0,0 +1,40 @@
from trustgraph.i18n import get_language_pack, get_translator, normalize_language
def test_normalize_language_handles_regions_and_accept_language():
assert normalize_language(None) == "en"
assert normalize_language("") == "en"
assert normalize_language("es-ES") == "es"
assert normalize_language("pt-BR") == "pt"
assert normalize_language("zh") == "zh-cn"
assert normalize_language("es-ES,es;q=0.9,en;q=0.8") == "es"
assert normalize_language("unknown") == "en"
def test_language_pack_loads_from_resources():
pack = get_language_pack("en")
assert isinstance(pack, dict)
# Key should exist and map to a non-empty string.
title = pack.get("cli.verify_system_status.title")
assert isinstance(title, str)
assert title.strip() != ""
def test_translator_formats_placeholders():
tr = get_translator("en")
out = tr.t(
"cli.verify_system_status.checking_attempt",
name="Pulsar",
attempt=2,
)
assert "Pulsar" in out
assert "2" in out
def test_translator_falls_back_to_key_for_unknown_keys():
tr = get_translator("en")
assert tr.t("missing.key") == "missing.key"

View file

@ -0,0 +1,130 @@
import argparse
import logging
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock
from trustgraph.base.logging import add_logging_args, setup_logging
def test_add_logging_args_uses_environment_defaults(monkeypatch):
monkeypatch.setenv("LOKI_URL", "http://example.test/loki")
monkeypatch.setenv("LOKI_USERNAME", "user")
monkeypatch.setenv("LOKI_PASSWORD", "pass")
parser = argparse.ArgumentParser()
add_logging_args(parser)
args = parser.parse_args([])
assert args.log_level == "INFO"
assert args.loki_enabled is True
assert args.loki_url == "http://example.test/loki"
assert args.loki_username == "user"
assert args.loki_password == "pass"
def test_add_logging_args_supports_disabling_loki():
parser = argparse.ArgumentParser()
add_logging_args(parser)
args = parser.parse_args(["--no-loki-enabled"])
assert args.loki_enabled is False
def test_setup_logging_without_loki_configures_console(monkeypatch):
basic_config = MagicMock()
logger = MagicMock()
monkeypatch.setattr(logging, "basicConfig", basic_config)
monkeypatch.setattr(logging, "getLogger", lambda name=None: logger)
setup_logging({"log_level": "debug", "loki_enabled": False, "id": "processor-1"})
kwargs = basic_config.call_args.kwargs
assert kwargs["level"] == logging.DEBUG
assert kwargs["force"] is True
assert "%(processor_id)s" in kwargs["format"]
assert len(kwargs["handlers"]) == 1
logger.info.assert_called_once_with("Logging configured with level: debug")
def test_setup_logging_with_loki_enables_queue_listener(monkeypatch):
basic_config = MagicMock()
root_logger = MagicMock()
module_logger = MagicMock()
urllib3_logger = MagicMock()
connectionpool_logger = MagicMock()
queue_handler = MagicMock()
queue_listener = MagicMock()
loki_handler = MagicMock()
noisy_logger = MagicMock()
logger_map = {
None: root_logger,
"trustgraph.base.logging": module_logger,
"urllib3": urllib3_logger,
"urllib3.connectionpool": connectionpool_logger,
"pika": noisy_logger,
"cassandra": noisy_logger,
}
monkeypatch.setattr(logging, "basicConfig", basic_config)
monkeypatch.setattr(logging, "getLogger", lambda name=None: logger_map[name])
monkeypatch.setattr(
logging.handlers,
"QueueHandler",
MagicMock(return_value=queue_handler),
)
monkeypatch.setattr(
logging.handlers,
"QueueListener",
MagicMock(return_value=queue_listener),
)
monkeypatch.setitem(
sys.modules,
"logging_loki",
SimpleNamespace(LokiHandler=MagicMock(return_value=loki_handler)),
)
setup_logging(
{
"log_level": "INFO",
"loki_enabled": True,
"loki_url": "http://loki.test/push",
"loki_username": "user",
"loki_password": "pass",
"id": "processor-1",
}
)
assert root_logger.loki_queue_listener is queue_listener
queue_listener.start.assert_called_once_with()
urllib3_logger.setLevel.assert_called_once_with(logging.WARNING)
connectionpool_logger.setLevel.assert_called_once_with(logging.WARNING)
module_logger.info.assert_any_call("Logging configured with level: INFO")
module_logger.info.assert_any_call("Loki logging enabled: http://loki.test/push")
def test_setup_logging_falls_back_when_loki_module_missing(monkeypatch, capsys):
basic_config = MagicMock()
logger = MagicMock()
monkeypatch.setattr(logging, "basicConfig", basic_config)
monkeypatch.setattr(logging, "getLogger", lambda name=None: logger)
monkeypatch.delitem(sys.modules, "logging_loki", raising=False)
real_import = __import__
def fake_import(name, *args, **kwargs):
if name == "logging_loki":
raise ImportError("missing")
return real_import(name, *args, **kwargs)
monkeypatch.setattr("builtins.__import__", fake_import)
setup_logging({"log_level": "INFO", "loki_enabled": True, "id": "processor-1"})
output = capsys.readouterr().out
assert "python-logging-loki not installed" in output
logger.warning.assert_called_once_with("Loki logging requested but not available")

View file

@ -0,0 +1,143 @@
from unittest.mock import MagicMock
import pytest
from trustgraph.base import metrics
@pytest.fixture(autouse=True)
def reset_metric_singletons():
"""Temporarily remove metric singletons so each test can inject mocks.
Saves any existing class-level metrics and restores them after the test
so that later tests in the same process still find the hasattr() guard
intact deleting without restoring causes every subsequent Processor()
construction to re-register the same Prometheus metric name, which raises
ValueError: Duplicated timeseries.
"""
classes_and_attrs = {
metrics.ConsumerMetrics: [
"state_metric",
"request_metric",
"processing_metric",
"rate_limit_metric",
],
metrics.ProducerMetrics: ["producer_metric"],
metrics.ProcessorMetrics: ["processor_metric"],
metrics.SubscriberMetrics: [
"state_metric",
"received_metric",
"dropped_metric",
],
}
saved = {}
for cls, attrs in classes_and_attrs.items():
for attr in attrs:
if hasattr(cls, attr):
saved[(cls, attr)] = getattr(cls, attr)
delattr(cls, attr)
yield
# Remove anything the test may have set, then restore originals
for cls, attrs in classes_and_attrs.items():
for attr in attrs:
if hasattr(cls, attr):
delattr(cls, attr)
for (cls, attr), value in saved.items():
setattr(cls, attr, value)
def test_consumer_metrics_reuses_singletons_and_records_events(monkeypatch):
enum_factory = MagicMock()
histogram_factory = MagicMock()
counter_factory = MagicMock()
state_labels = MagicMock()
request_labels = MagicMock()
processing_labels = MagicMock()
rate_limit_labels = MagicMock()
timer = MagicMock()
enum_factory.return_value.labels.return_value = state_labels
histogram_factory.return_value.labels.return_value = request_labels
request_labels.time.return_value = timer
counter_factory.side_effect = [
MagicMock(labels=MagicMock(return_value=processing_labels)),
MagicMock(labels=MagicMock(return_value=rate_limit_labels)),
]
monkeypatch.setattr(metrics, "Enum", enum_factory)
monkeypatch.setattr(metrics, "Histogram", histogram_factory)
monkeypatch.setattr(metrics, "Counter", counter_factory)
first = metrics.ConsumerMetrics("proc", "flow", "name")
second = metrics.ConsumerMetrics("proc-2", "flow-2", "name-2")
assert enum_factory.call_count == 1
assert histogram_factory.call_count == 1
assert counter_factory.call_count == 2
first.process("ok")
first.rate_limit()
first.state("running")
assert first.record_time() is timer
processing_labels.inc.assert_called_once_with()
rate_limit_labels.inc.assert_called_once_with()
state_labels.state.assert_called_once_with("running")
def test_producer_metrics_increments_counter_once(monkeypatch):
counter_factory = MagicMock()
labels = MagicMock()
counter_factory.return_value.labels.return_value = labels
monkeypatch.setattr(metrics, "Counter", counter_factory)
producer_metrics = metrics.ProducerMetrics("proc", "flow", "output")
producer_metrics.inc()
counter_factory.assert_called_once()
labels.inc.assert_called_once_with()
def test_processor_metrics_reports_info(monkeypatch):
info_factory = MagicMock()
labels = MagicMock()
info_factory.return_value.labels.return_value = labels
monkeypatch.setattr(metrics, "Info", info_factory)
processor_metrics = metrics.ProcessorMetrics("proc")
processor_metrics.info({"kind": "test"})
info_factory.assert_called_once()
labels.info.assert_called_once_with({"kind": "test"})
def test_subscriber_metrics_tracks_received_state_and_dropped(monkeypatch):
enum_factory = MagicMock()
counter_factory = MagicMock()
state_labels = MagicMock()
received_labels = MagicMock()
dropped_labels = MagicMock()
enum_factory.return_value.labels.return_value = state_labels
counter_factory.side_effect = [
MagicMock(labels=MagicMock(return_value=received_labels)),
MagicMock(labels=MagicMock(return_value=dropped_labels)),
]
monkeypatch.setattr(metrics, "Enum", enum_factory)
monkeypatch.setattr(metrics, "Counter", counter_factory)
subscriber_metrics = metrics.SubscriberMetrics("proc", "flow", "input")
subscriber_metrics.received()
subscriber_metrics.state("running")
subscriber_metrics.dropped("ignored")
received_labels.inc.assert_called_once_with()
dropped_labels.inc.assert_called_once_with()
state_labels.state.assert_called_once_with("running")

View file

@ -236,6 +236,10 @@ async def test_subscriber_graceful_shutdown():
with patch.object(subscriber, 'run') as mock_run:
# Mock run that simulates graceful shutdown
async def mock_run_graceful():
# Honor the readiness contract: real run() signals _ready
# after binding the consumer, so start() can unblock. Mocks
# of run() must do the same or start() hangs forever.
subscriber._ready.set_result(None)
# Process messages while running, then drain
while subscriber.running or subscriber.draining:
if subscriber.draining:
@ -337,6 +341,8 @@ async def test_subscriber_pending_acks_cleanup():
with patch.object(subscriber, 'run') as mock_run:
# Mock run that simulates cleanup of pending acks
async def mock_run_cleanup():
# Honor the readiness contract — see test_subscriber_graceful_shutdown.
subscriber._ready.set_result(None)
while subscriber.running or subscriber.draining:
await asyncio.sleep(0.05)
if subscriber.draining:
@ -406,4 +412,4 @@ async def test_subscriber_multiple_subscribers():
msg1 = await queue1.get()
msg_all = await queue_all.get()
assert msg1 == {"data": "broadcast"}
assert msg_all == {"data": "broadcast"}
assert msg_all == {"data": "broadcast"}

View file

@ -0,0 +1,189 @@
"""
Regression tests for Subscriber.start() readiness barrier.
Background: prior to the eager-connect fix, Subscriber.start() created
the run() task and returned immediately. The underlying backend consumer
was lazily connected on its first receive() call, which left a setup
race for request/response clients using ephemeral per-subscriber response
queues (RabbitMQ auto-delete exclusive queues): the request would be
published before the response queue was bound, and the broker would
silently drop the reply. fetch_config(), document-embeddings, and
api-gateway all hit this with "Failed to fetch config on notify" /
"Request timeout exception" symptoms.
These tests pin the readiness contract:
await subscriber.start()
# at this point, consumer.ensure_connected() MUST have run
so that any future change which removes the eager bind, or moves it
back to lazy initialisation, fails CI loudly.
"""
import asyncio
import pytest
from unittest.mock import MagicMock
from trustgraph.base.subscriber import Subscriber
def _make_backend(ensure_connected_side_effect=None,
receive_side_effect=None):
"""Build a fake backend whose consumer records ensure_connected /
receive calls. ensure_connected_side_effect lets a test inject a
delay or exception."""
backend = MagicMock()
consumer = MagicMock()
consumer.ensure_connected = MagicMock(
side_effect=ensure_connected_side_effect,
)
# By default receive raises a timeout-style exception that the
# subscriber loop is supposed to swallow as a "no message yet" — this
# keeps the subscriber idling cleanly while the test inspects state.
if receive_side_effect is None:
receive_side_effect = TimeoutError("No message received within timeout")
consumer.receive = MagicMock(side_effect=receive_side_effect)
consumer.acknowledge = MagicMock()
consumer.negative_acknowledge = MagicMock()
consumer.pause_message_listener = MagicMock()
consumer.unsubscribe = MagicMock()
consumer.close = MagicMock()
backend.create_consumer.return_value = consumer
return backend, consumer
def _make_subscriber(backend):
return Subscriber(
backend=backend,
topic="response:tg:config",
subscription="test-sub",
consumer_name="test-consumer",
schema=dict,
max_size=10,
drain_timeout=1.0,
backpressure_strategy="block",
)
class TestSubscriberReadiness:
@pytest.mark.asyncio
async def test_start_calls_ensure_connected_before_returning(self):
"""The barrier: ensure_connected must have been invoked at least
once by the time start() returns."""
backend, consumer = _make_backend()
subscriber = _make_subscriber(backend)
await subscriber.start()
try:
consumer.ensure_connected.assert_called_once()
finally:
await subscriber.stop()
@pytest.mark.asyncio
async def test_start_blocks_until_ensure_connected_completes(self):
"""If ensure_connected is slow, start() must wait for it. This is
the actual race-condition guard it would have failed against
the buggy version where start() returned before run() had even
scheduled the consumer creation."""
connect_started = asyncio.Event()
release_connect = asyncio.Event()
# ensure_connected runs in the executor thread, so we need a
# threading-safe gate. Use a simple busy-wait on a flag set by
# the asyncio side via call_soon_threadsafe — but the simpler
# path is to give it a sleep and observe ordering.
import threading
gate = threading.Event()
def slow_connect():
connect_started.set() # safe: only mutates the Event flag
gate.wait(timeout=2.0)
backend, consumer = _make_backend(
ensure_connected_side_effect=slow_connect,
)
subscriber = _make_subscriber(backend)
start_task = asyncio.create_task(subscriber.start())
# Wait until ensure_connected has begun executing.
await asyncio.wait_for(connect_started.wait(), timeout=2.0)
# ensure_connected is in flight — start() must NOT have returned.
assert not start_task.done(), (
"start() returned before ensure_connected() completed — "
"the readiness barrier is broken and the request/response "
"race condition is back."
)
# Release the gate; start() should now complete promptly.
gate.set()
await asyncio.wait_for(start_task, timeout=2.0)
consumer.ensure_connected.assert_called_once()
await subscriber.stop()
@pytest.mark.asyncio
async def test_start_propagates_consumer_creation_failure(self):
"""If create_consumer() raises, start() must surface the error
rather than hang on the readiness future. The old code path
retried indefinitely inside run() and never let start() unblock."""
backend = MagicMock()
backend.create_consumer.side_effect = RuntimeError("broker down")
subscriber = _make_subscriber(backend)
with pytest.raises(RuntimeError, match="broker down"):
await asyncio.wait_for(subscriber.start(), timeout=2.0)
@pytest.mark.asyncio
async def test_start_propagates_ensure_connected_failure(self):
"""Same contract for an ensure_connected() that raises (e.g. the
broker is up but the queue declare/bind fails)."""
backend, consumer = _make_backend(
ensure_connected_side_effect=RuntimeError("queue declare failed"),
)
subscriber = _make_subscriber(backend)
with pytest.raises(RuntimeError, match="queue declare failed"):
await asyncio.wait_for(subscriber.start(), timeout=2.0)
@pytest.mark.asyncio
async def test_ensure_connected_runs_before_subscriber_running_log(self):
"""Subtle ordering: ensure_connected MUST happen before the
receive loop, so that any reply is captured. We assert this by
checking ensure_connected was called before any receive call."""
call_order = []
def record_ensure():
call_order.append("ensure_connected")
def record_receive(*args, **kwargs):
call_order.append("receive")
raise TimeoutError("No message received within timeout")
backend, consumer = _make_backend(
ensure_connected_side_effect=record_ensure,
receive_side_effect=record_receive,
)
subscriber = _make_subscriber(backend)
await subscriber.start()
# Give the receive loop a tick to run at least once.
await asyncio.sleep(0.05)
await subscriber.stop()
# ensure_connected must come first; receive may not have happened
# yet on a fast machine, but if it did, it must come after.
assert call_order, "neither ensure_connected nor receive was called"
assert call_order[0] == "ensure_connected"

View file

@ -28,7 +28,6 @@ def sample_text_document():
"""Sample document with moderate length text."""
metadata = Metadata(
id="test-doc-1",
user="test-user",
collection="test-collection"
)
text = "The quick brown fox jumps over the lazy dog. " * 20
@ -43,7 +42,6 @@ def long_text_document():
"""Long document for testing multiple chunks."""
metadata = Metadata(
id="test-doc-long",
user="test-user",
collection="test-collection"
)
# Create a long text that will definitely be chunked
@ -59,7 +57,6 @@ def unicode_text_document():
"""Document with various unicode characters."""
metadata = Metadata(
id="test-doc-unicode",
user="test-user",
collection="test-collection"
)
text = """
@ -84,7 +81,6 @@ def empty_text_document():
"""Empty document for edge case testing."""
metadata = Metadata(
id="test-doc-empty",
user="test-user",
collection="test-collection"
)
return TextDocument(

View file

@ -70,11 +70,12 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
# Flow exposes parameter lookup via __call__: flow("chunk-size")
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 2000, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -105,10 +106,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 200 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -139,10 +140,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 1500, # Override chunk size
"chunk-overlap": 150 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -184,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-123",
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content"
@ -195,15 +195,15 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
# Flow.__call__ resolves parameters and producers/consumers from the
# same dict — merge both kinds here.
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 1500,
"chunk-overlap": 150,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(name)
}.get(key)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -241,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.return_value = None # No overrides
mock_flow.side_effect = lambda key: None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(

View file

@ -70,11 +70,12 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
# Flow exposes parameter lookup via __call__: flow("chunk-size")
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 400, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -105,10 +106,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 25 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -139,10 +140,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 350, # Override chunk size
"chunk-overlap": 30 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -184,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-456",
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content for token chunking"
@ -195,15 +195,15 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
# Flow.__call__ resolves parameters and producers/consumers from the
# same dict — merge both kinds here.
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 400,
"chunk-overlap": 40,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(name)
}.get(key)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -245,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.return_value = None # No overrides
mock_flow.side_effect = lambda key: None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(

View file

@ -109,7 +109,8 @@ class TestListConfigItems:
url='http://custom.com',
config_type='prompt',
format_type='json',
token=None
token=None,
workspace='default'
)
def test_list_main_uses_defaults(self):
@ -128,7 +129,8 @@ class TestListConfigItems:
url='http://localhost:8088/',
config_type='prompt',
format_type='text',
token=None
token=None,
workspace='default'
)
@ -196,7 +198,8 @@ class TestGetConfigItem:
config_type='prompt',
key='template-1',
format_type='json',
token=None
token=None,
workspace='default'
)
@ -253,7 +256,8 @@ class TestPutConfigItem:
config_type='prompt',
key='new-template',
value='Custom prompt: {input}',
token=None
token=None,
workspace='default'
)
def test_put_main_with_stdin_arg(self):
@ -278,7 +282,8 @@ class TestPutConfigItem:
config_type='prompt',
key='stdin-template',
value=stdin_content,
token=None
token=None,
workspace='default'
)
def test_put_main_mutually_exclusive_args(self):
@ -334,7 +339,8 @@ class TestDeleteConfigItem:
url='http://custom.com',
config_type='prompt',
key='old-template',
token=None
token=None,
workspace='default'
)

View file

@ -48,7 +48,7 @@ def knowledge_loader():
return KnowledgeLoader(
files=["test.ttl"],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc-123",
url="http://test.example.com/",
@ -64,7 +64,7 @@ class TestKnowledgeLoader:
loader = KnowledgeLoader(
files=["file1.ttl", "file2.ttl"],
flow="my-flow",
user="user1",
workspace="user1",
collection="col1",
document_id="doc1",
url="http://example.com/",
@ -73,7 +73,7 @@ class TestKnowledgeLoader:
assert loader.files == ["file1.ttl", "file2.ttl"]
assert loader.flow == "my-flow"
assert loader.user == "user1"
assert loader.workspace == "user1"
assert loader.collection == "col1"
assert loader.document_id == "doc1"
assert loader.url == "http://example.com/"
@ -126,7 +126,7 @@ ex:mary ex:knows ex:bob .
loader = KnowledgeLoader(
files=[f.name],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/"
@ -151,7 +151,7 @@ ex:mary ex:knows ex:bob .
loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/",
@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
# Verify Api was created with correct parameters
mock_api_class.assert_called_once_with(
url="http://test.example.com/",
token="test-token"
token="test-token",
workspace="test-user"
)
# Verify bulk client was obtained
@ -174,7 +175,6 @@ ex:mary ex:knows ex:bob .
call_args = mock_bulk.import_triples.call_args
assert call_args[1]['flow'] == "test-flow"
assert call_args[1]['metadata']['id'] == "test-doc"
assert call_args[1]['metadata']['user'] == "test-user"
assert call_args[1]['metadata']['collection'] == "test-collection"
# Verify import_entity_contexts was called
@ -198,7 +198,7 @@ class TestCLIArgumentParsing:
'tg-load-knowledge',
'-i', 'doc-123',
'-f', 'my-flow',
'-U', 'my-user',
'-w', 'my-user',
'-C', 'my-collection',
'-u', 'http://custom.example.com/',
'-t', 'my-token',
@ -216,7 +216,7 @@ class TestCLIArgumentParsing:
token='my-token',
flow='my-flow',
files=['file1.ttl', 'file2.ttl'],
user='my-user',
workspace='my-user',
collection='my-collection'
)
@ -242,7 +242,7 @@ class TestCLIArgumentParsing:
# Verify defaults were used
call_args = mock_loader_class.call_args[1]
assert call_args['flow'] == 'default'
assert call_args['user'] == 'trustgraph'
assert call_args['workspace'] == 'default'
assert call_args['collection'] == 'default'
assert call_args['url'] == 'http://localhost:8088/'
assert call_args['token'] is None
@ -287,7 +287,7 @@ class TestErrorHandling:
loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/"

View file

@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_set_main_structured_query_no_arguments_needed(self):
@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_valid_types_includes_row_embeddings_query(self):
@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
show_main()
mock_show.assert_called_once_with(url='http://custom.com', token=None)
mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
class TestShowToolsRowEmbeddingsQuery:

View file

@ -73,7 +73,6 @@ class TestSyncDocumentEmbeddingsClient:
# Act
result = client.request(
vector=vector,
user="test_user",
collection="test_collection",
limit=10,
timeout=300
@ -82,7 +81,6 @@ class TestSyncDocumentEmbeddingsClient:
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.call.assert_called_once_with(
user="test_user",
collection="test_collection",
vector=vector,
limit=10,
@ -108,7 +106,6 @@ class TestSyncDocumentEmbeddingsClient:
# Assert
assert result == ["test_chunk"]
client.call.assert_called_once_with(
user="trustgraph",
collection="default",
vector=vector,
limit=10,

View file

@ -31,7 +31,6 @@ def _make_query(
query = Query(
rag=rag,
user="test-user",
collection="test-collection",
verbose=False,
entity_limit=entity_limit,
@ -208,7 +207,6 @@ class TestBatchTripleQueries:
assert calls[0].kwargs["p"] is None
assert calls[0].kwargs["o"] is None
assert calls[0].kwargs["limit"] == 15
assert calls[0].kwargs["user"] == "test-user"
assert calls[0].kwargs["collection"] == "test-collection"
assert calls[0].kwargs["batch_size"] == 20

View file

@ -28,10 +28,12 @@ def mock_flow_config():
"""Mock flow configuration."""
mock_config = Mock()
mock_config.flows = {
"test-flow": {
"interfaces": {
"triples-store": "test-triples-queue",
"graph-embeddings-store": "test-ge-queue"
"test-user": {
"test-flow": {
"interfaces": {
"triples-store": {"flow": "test-triples-queue"},
"graph-embeddings-store": {"flow": "test-ge-queue"}
}
}
}
}
@ -43,7 +45,7 @@ def mock_flow_config():
def mock_request():
"""Mock knowledge load request."""
request = Mock()
request.user = "test-user"
request.workspace = "test-user"
request.id = "test-doc-id"
request.collection = "test-collection"
request.flow = "test-flow"
@ -71,7 +73,6 @@ def sample_triples():
return Triples(
metadata=Metadata(
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
),
triples=[
@ -90,7 +91,6 @@ def sample_graph_embeddings():
return GraphEmbeddings(
metadata=Metadata(
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
),
entities=[
@ -146,7 +146,6 @@ class TestKnowledgeManagerLoadCore:
mock_triples_pub.send.assert_called_once()
sent_triples = mock_triples_pub.send.call_args[0][1]
assert sent_triples.metadata.collection == "test-collection"
assert sent_triples.metadata.user == "test-user"
assert sent_triples.metadata.id == "test-doc-id"
@pytest.mark.asyncio
@ -185,7 +184,6 @@ class TestKnowledgeManagerLoadCore:
mock_ge_pub.send.assert_called_once()
sent_ge = mock_ge_pub.send.call_args[0][1]
assert sent_ge.metadata.collection == "test-collection"
assert sent_ge.metadata.user == "test-user"
assert sent_ge.metadata.id == "test-doc-id"
@pytest.mark.asyncio
@ -193,7 +191,7 @@ class TestKnowledgeManagerLoadCore:
"""Test that load_kg_core falls back to 'default' when request.collection is None."""
# Create request with None collection
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_request.collection = None # Should fall back to "default"
mock_request.flow = "test-flow"
@ -269,7 +267,7 @@ class TestKnowledgeManagerLoadCore:
"""Test that load_kg_core validates flow configuration before processing."""
# Request with invalid flow
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_request.collection = "test-collection"
mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows
@ -297,7 +295,7 @@ class TestKnowledgeManagerLoadCore:
# Test missing ID
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = None # Missing
mock_request.collection = "test-collection"
mock_request.flow = "test-flow"
@ -323,7 +321,7 @@ class TestKnowledgeManagerOtherMethods:
async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples):
"""Test that get_kg_core preserves collection field from stored data."""
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_respond = AsyncMock()
@ -354,7 +352,7 @@ class TestKnowledgeManagerOtherMethods:
async def test_list_kg_cores(self, knowledge_manager):
"""Test listing knowledge cores."""
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_respond = AsyncMock()
@ -376,7 +374,7 @@ class TestKnowledgeManagerOtherMethods:
async def test_delete_kg_core(self, knowledge_manager):
"""Test deleting knowledge cores."""
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_respond = AsyncMock()

View file

@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message with inline data
content = b"# Document Title\nBody text content."
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,
@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,
@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
]
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,

View file

@ -12,7 +12,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_basic(self):
"""Test basic collection name creation"""
result = make_safe_collection_name(
user="test_user",
workspace="test_user",
collection="test_collection",
prefix="doc"
)
@ -21,7 +21,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_special_characters(self):
"""Test collection name creation with special characters that need sanitization"""
result = make_safe_collection_name(
user="user@domain.com",
workspace="user@domain.com",
collection="test-collection.v2",
prefix="entity"
)
@ -30,7 +30,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_unicode(self):
"""Test collection name creation with Unicode characters"""
result = make_safe_collection_name(
user="测试用户",
workspace="测试用户",
collection="colección_española",
prefix="doc"
)
@ -39,7 +39,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_spaces(self):
"""Test collection name creation with spaces"""
result = make_safe_collection_name(
user="test user",
workspace="test user",
collection="my test collection",
prefix="entity"
)
@ -48,7 +48,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self):
"""Test collection name creation with multiple consecutive special characters"""
result = make_safe_collection_name(
user="user@@@domain!!!",
workspace="user@@@domain!!!",
collection="test---collection...v2",
prefix="doc"
)
@ -57,7 +57,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_leading_trailing_underscores(self):
"""Test collection name creation with leading/trailing special characters"""
result = make_safe_collection_name(
user="__test_user__",
workspace="__test_user__",
collection="@@test_collection##",
prefix="entity"
)
@ -66,7 +66,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_empty_user(self):
"""Test collection name creation with empty user (should fallback to 'default')"""
result = make_safe_collection_name(
user="",
workspace="",
collection="test_collection",
prefix="doc"
)
@ -75,7 +75,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_empty_collection(self):
"""Test collection name creation with empty collection (should fallback to 'default')"""
result = make_safe_collection_name(
user="test_user",
workspace="test_user",
collection="",
prefix="doc"
)
@ -84,7 +84,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_both_empty(self):
"""Test collection name creation with both user and collection empty"""
result = make_safe_collection_name(
user="",
workspace="",
collection="",
prefix="doc"
)
@ -93,7 +93,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_only_special_characters(self):
"""Test collection name creation with only special characters (should fallback to 'default')"""
result = make_safe_collection_name(
user="@@@!!!",
workspace="@@@!!!",
collection="---###",
prefix="entity"
)
@ -102,7 +102,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_whitespace_only(self):
"""Test collection name creation with whitespace-only strings"""
result = make_safe_collection_name(
user=" \n\t ",
workspace=" \n\t ",
collection=" \r\n ",
prefix="doc"
)
@ -111,7 +111,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_mixed_valid_invalid_chars(self):
"""Test collection name creation with mixed valid and invalid characters"""
result = make_safe_collection_name(
user="user123@test",
workspace="user123@test",
collection="coll_2023.v1",
prefix="entity"
)
@ -147,7 +147,7 @@ class TestMilvusCollectionNaming:
long_collection = "b" * 100
result = make_safe_collection_name(
user=long_user,
workspace=long_user,
collection=long_collection,
prefix="doc"
)
@ -159,7 +159,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_numeric_values(self):
"""Test collection name creation with numeric user/collection values"""
result = make_safe_collection_name(
user="user123",
workspace="user123",
collection="collection456",
prefix="doc"
)
@ -168,7 +168,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_case_sensitivity(self):
"""Test that collection name creation preserves case"""
result = make_safe_collection_name(
user="TestUser",
workspace="TestUser",
collection="TestCollection",
prefix="Doc"
)

View file

@ -20,9 +20,8 @@ def processor():
)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"):
metadata = Metadata(id=doc_id, collection=collection)
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
msg = MagicMock()
msg.value.return_value = value
@ -127,7 +126,7 @@ class TestDocumentEmbeddingsProcessor:
@pytest.mark.asyncio
async def test_metadata_preserved(self, processor):
"""Output should carry the original metadata."""
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
msg = _make_chunk_message(collection="reports", doc_id="d1")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]]
@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor:
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.metadata.user == "alice"
assert result.metadata.collection == "reports"
assert result.metadata.id == "d1"

View file

@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"):
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
def _make_message(entities, doc_id="doc-1", collection="default"):
metadata = Metadata(id=doc_id, collection=collection)
value = EntityContexts(metadata=metadata, entities=entities)
msg = MagicMock()
msg.value.return_value = value
@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing:
_make_entity_context(f"E{i}", f"ctx {i}")
for i in range(5)
]
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
msg = _make_message(entities, doc_id="doc-42", collection="main")
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
mock_output = AsyncMock()
@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing:
for call in mock_output.send.call_args_list:
result = call[0][0]
assert result.metadata.id == "doc-42"
assert result.metadata.user == "alice"
assert result.metadata.collection == "main"
@pytest.mark.asyncio

View file

@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert 'customers' in processor.schemas
assert processor.schemas['customers'].name == 'customers'
assert len(processor.schemas['customers'].fields) == 3
assert 'customers' in processor.schemas["default"]
assert processor.schemas["default"]['customers'].name == 'customers'
assert len(processor.schemas["default"]['customers'].fields) == 3
async def test_on_schema_config_handles_missing_type(self):
"""Test that missing schema type is handled gracefully"""
@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
'other_type': {}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert processor.schemas == {}
assert processor.schemas.get("default", {}) == {}
async def test_on_message_drops_unknown_collection(self):
"""Test that messages for unknown collections are dropped"""
@ -285,7 +285,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('default', 'test_collection')] = {}
# No schemas registered
metadata = MagicMock()
@ -322,17 +322,19 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('default', 'test_collection')] = {}
# Set up schema
processor.schemas['customers'] = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
processor.schemas["default"] = {
'customers': RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
}
metadata = MagicMock()
metadata.user = 'test_user'
@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
return MagicMock()
mock_flow = MagicMock(side_effect=flow_factory)
mock_flow.workspace = "default"
await processor.on_message(mock_msg, MagicMock(), mock_flow)

View file

@ -0,0 +1,200 @@
"""
Unit tests for extract_with_simplified_format.
Regression guard for the bug where the extractor read
``result.object`` (singular, used for response_type="json") instead of
``result.objects`` (plural, used for response_type="jsonl"). The
extract-with-ontologies prompt is JSONL, so reading the wrong field
silently dropped every extraction and left the knowledge graph
populated only by ontology schema + document provenance.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
from trustgraph.base import PromptResult
@pytest.fixture
def extractor():
"""Create a Processor instance without running its heavy __init__.
Matches the pattern used in test_prompt_and_extraction.py: only
the attributes the code under test touches need to be set.
"""
ex = object.__new__(Processor)
ex.URI_PREFIXES = {
"rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
"rdfs:": "http://www.w3.org/2000/01/rdf-schema#",
"owl:": "http://www.w3.org/2002/07/owl#",
"xsd:": "http://www.w3.org/2001/XMLSchema#",
}
return ex
@pytest.fixture
def food_subset():
"""A minimal food ontology subset the extracted entities reference."""
return OntologySubset(
ontology_id="food",
classes={
"Recipe": {
"uri": "http://purl.org/ontology/fo/Recipe",
"type": "owl:Class",
"labels": [{"value": "Recipe", "lang": "en-gb"}],
"comment": "A Recipe.",
},
"Food": {
"uri": "http://purl.org/ontology/fo/Food",
"type": "owl:Class",
"labels": [{"value": "Food", "lang": "en-gb"}],
"comment": "A Food.",
},
},
object_properties={
"ingredients": {
"uri": "http://purl.org/ontology/fo/ingredients",
"type": "owl:ObjectProperty",
"labels": [{"value": "ingredients", "lang": "en-gb"}],
"comment": "Relates a recipe to its ingredients.",
"domain": "Recipe",
"range": "Food",
},
},
datatype_properties={},
metadata={
"name": "Food Ontology",
"namespace": "http://purl.org/ontology/fo/",
},
)
def _flow_with_prompt_result(prompt_result):
"""Build the ``flow(name)`` callable the extractor invokes.
``extract_with_simplified_format`` calls
``flow("prompt-request").prompt(...)`` so we need ``flow`` to be
callable, return an object whose ``.prompt`` is an AsyncMock that
resolves to ``prompt_result``.
"""
prompt_service = MagicMock()
prompt_service.prompt = AsyncMock(return_value=prompt_result)
def flow(name):
assert name == "prompt-request", (
f"extractor should only invoke flow('prompt-request'), "
f"got {name!r}"
)
return prompt_service
return flow, prompt_service.prompt
class TestReadsObjectsForJsonlPrompt:
"""extract-with-ontologies is a JSONL prompt; the extractor must
read ``result.objects``, not ``result.object``."""
async def test_populated_objects_produces_triples(
self, extractor, food_subset,
):
"""Happy path: PromptResult with populated .objects -> non-empty
triples list."""
prompt_result = PromptResult(
response_type="jsonl",
objects=[
{"type": "entity", "entity": "Cornish Pasty",
"entity_type": "Recipe"},
{"type": "entity", "entity": "beef",
"entity_type": "Food"},
{"type": "relationship",
"subject": "Cornish Pasty", "subject_type": "Recipe",
"relation": "ingredients",
"object": "beef", "object_type": "Food"},
],
)
flow, prompt_mock = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "some chunk", food_subset, {"text": "some chunk"},
)
prompt_mock.assert_awaited_once()
assert triples, (
"extract_with_simplified_format returned no triples; if "
"this fails, the extractor is probably reading .object "
"instead of .objects again"
)
async def test_none_objects_returns_empty_without_crashing(
self, extractor, food_subset,
):
"""The exact shape that hit production on v2.3: the extractor
was reading ``.object`` for a JSONL prompt, which returned
``None`` and tripped the parser's 'Unexpected response type'
path. With the fix we read ``.objects``; if that's also
``None`` we must still return ``[]`` cleanly, not crash."""
prompt_result = PromptResult(
response_type="jsonl",
objects=None,
)
flow, _ = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "chunk", food_subset, {"text": "chunk"},
)
assert triples == []
async def test_empty_objects_returns_empty(
self, extractor, food_subset,
):
"""Valid JSONL response with zero entries should yield zero
triples, not raise."""
prompt_result = PromptResult(
response_type="jsonl",
objects=[],
)
flow, _ = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "chunk", food_subset, {"text": "chunk"},
)
assert triples == []
async def test_ignores_object_field_for_jsonl_prompt(
self, extractor, food_subset,
):
"""If ``.object`` is somehow set but ``.objects`` is None, the
extractor must not silently fall back to ``.object``. This
guards against a well-meaning regression that "helpfully"
re-adds fallback fields.
The extractor should read only ``.objects`` for this prompt;
when that is None we expect the empty-result path.
"""
prompt_result = PromptResult(
response_type="json",
object={"not": "the field we should be reading"},
objects=None,
)
flow, _ = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "chunk", food_subset, {"text": "chunk"},
)
assert triples == [], (
"Extractor fell back to .object for a JSONL prompt — "
"this is the regression shape we are trying to prevent"
)

View file

@ -231,6 +231,52 @@ class TestTripleValidation:
is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset)
assert is_valid == expected, f"Validation of {predicate} should be {expected}"
def test_validates_domain_correctly_with_entity_types(self, extractor, sample_ontology_subset):
"""Test domain validation correctly compares against extracted entity_types."""
subject = "my-recipe"
predicate = "produces"
object_val = "my-food"
# Proper domain for produces is Recipe
entity_types = {
"my-recipe": "Recipe",
"my-food": "Food"
}
is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types)
assert is_valid, "Valid domain should be accepted"
# Invalid domain
entity_types_invalid = {
"my-recipe": "Ingredient",
"my-food": "Food"
}
is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid)
assert not is_invalid, "Invalid domain should be rejected"
def test_validates_range_correctly_with_entity_types(self, extractor, sample_ontology_subset):
"""Test range validation correctly compares against extracted entity_types."""
subject = "my-recipe"
predicate = "produces"
object_val = "my-food"
# Proper range for produces is Food
entity_types = {
"my-recipe": "Recipe",
"my-food": "Food"
}
is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types)
assert is_valid, "Valid range should be accepted"
# Invalid range
entity_types_invalid = {
"my-recipe": "Recipe",
"my-food": "Recipe"
}
is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid)
assert not is_invalid, "Invalid range should be rejected"
class TestTripleParsing:
"""Test suite for parsing triples from LLM responses."""

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.definitions.extract import (
Processor, default_triples_batch_size, default_entity_batch_size,
)
from trustgraph.base import PromptResult
from trustgraph.schema import (
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
)
@ -33,11 +34,10 @@ def _make_defn(entity, definition):
return {"entity": entity, "definition": definition}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
id=meta_id, root=root, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
@ -51,8 +51,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_ecs_pub = AsyncMock()
mock_prompt_client = AsyncMock()
if isinstance(prompt_result, list):
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
else:
wrapped = PromptResult(response_type="text", text=prompt_result)
mock_prompt_client.extract_definitions = AsyncMock(
return_value=prompt_result
return_value=wrapped
)
def flow(name):
@ -224,8 +228,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
"text", meta_id="c-1", root="r-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
@ -233,7 +236,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(triples_pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"
@pytest.mark.asyncio
@ -242,8 +244,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-2", root="r-2",
user="u-2", collection="coll-2",
"text", meta_id="c-2", root="r-2", collection="coll-2",
)
await proc.on_message(msg, MagicMock(), flow)

View file

@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
from trustgraph.schema import (
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
)
from trustgraph.base import PromptResult
# ---------------------------------------------------------------------------
@ -37,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True):
}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
"""Build a mock message wrapping a Chunk."""
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
id=meta_id, root=root, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
@ -58,7 +58,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_relationships = AsyncMock(
return_value=prompt_result
return_value=PromptResult(
response_type="jsonl",
objects=prompt_result,
)
)
def flow(name):
@ -185,8 +188,7 @@ class TestMetadataPreservation:
rels = [_make_rel("X", "rel", "Y")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
"text", meta_id="c-1", root="r-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
@ -194,7 +196,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"

View file

@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
ConfigReceiver.config_loader = Mock()
def _notify(version, changes):
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestConfigReceiver:
"""Test cases for ConfigReceiver class"""
@ -47,98 +53,70 @@ class TestConfigReceiver:
assert handler2 in config_receiver.flow_handlers
@pytest.mark.asyncio
async def test_on_config_notify_new_version(self):
"""Test on_config_notify triggers fetch for newer version"""
async def test_on_config_notify_new_version_fetches_per_workspace(self):
"""Notify with newer version fetches each affected workspace."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
# Mock fetch_and_apply
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with newer version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 1
msg = _notify(2, {"flow": ["ws1", "ws2"]})
await config_receiver.on_config_notify(msg, None, None)
assert set(fetch_calls) == {"ws1", "ws2"}
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_old_version_ignored(self):
"""Test on_config_notify ignores older versions"""
"""Older-version notifies are ignored."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 5
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with older version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=3, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 0
msg = _notify(3, {"flow": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
assert fetch_calls == []
@pytest.mark.asyncio
async def test_on_config_notify_irrelevant_types_ignored(self):
"""Test on_config_notify ignores types the gateway doesn't care about"""
"""Notifies without flow changes advance version but skip fetch."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with non-flow type
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
# Version should be updated but no fetch
assert len(fetch_calls) == 0
msg = _notify(2, {"prompt": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
assert fetch_calls == []
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_flow_type_triggers_fetch(self):
"""Test on_config_notify fetches for flow-related types"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
for type_name in ["flow", "active-flow"]:
fetch_calls.clear()
config_receiver.config_version = 1
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=[type_name])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
@pytest.mark.asyncio
async def test_on_config_notify_exception_handling(self):
"""Test on_config_notify handles exceptions gracefully"""
"""on_config_notify swallows exceptions from message decode."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Create notify message that causes an exception
mock_msg = Mock()
mock_msg.value.side_effect = Exception("Test exception")
@ -146,19 +124,18 @@ class TestConfigReceiver:
await config_receiver.on_config_notify(mock_msg, None, None)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_new_flows(self):
"""Test fetch_and_apply starts new flows"""
async def test_fetch_and_apply_workspace_starts_new_flows(self):
"""fetch_and_apply_workspace starts newly-configured flows."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Mock _create_config_client to return a mock client
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}',
"flow2": '{"name": "test_flow_2"}'
"flow2": '{"name": "test_flow_2"}',
}
}
@ -167,36 +144,39 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
start_flow_calls = []
async def mock_start_flow(id, flow):
start_flow_calls.append((id, flow))
async def mock_start_flow(workspace, id, flow):
start_flow_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.config_version == 5
assert "flow1" in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" in config_receiver.flows["default"]
assert len(start_flow_calls) == 2
assert all(c[0] == "default" for c in start_flow_calls)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_removed_flows(self):
"""Test fetch_and_apply stops removed flows"""
async def test_fetch_and_apply_workspace_stops_removed_flows(self):
"""fetch_and_apply_workspace stops flows no longer configured."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate with existing flows
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config now only has flow1
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}'
"flow1": '{"name": "test_flow_1"}',
}
}
@ -205,20 +185,22 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
stop_flow_calls = []
async def mock_stop_flow(id, flow):
stop_flow_calls.append((id, flow))
async def mock_stop_flow(workspace, id, flow):
stop_flow_calls.append((workspace, id, flow))
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" in config_receiver.flows
assert "flow2" not in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" not in config_receiver.flows["default"]
assert len(stop_flow_calls) == 1
assert stop_flow_calls[0][0] == "flow2"
assert stop_flow_calls[0][:2] == ("default", "flow2")
@pytest.mark.asyncio
async def test_fetch_and_apply_with_no_flows(self):
"""Test fetch_and_apply with empty config"""
async def test_fetch_and_apply_workspace_with_no_flows(self):
"""Empty workspace config clears any local flow state."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
@ -231,88 +213,100 @@ class TestConfigReceiver:
mock_client.request.return_value = mock_resp
config_receiver._create_config_client = Mock(return_value=mock_client)
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.flows == {}
assert config_receiver.flows.get("default", {}) == {}
assert config_receiver.config_version == 1
@pytest.mark.asyncio
async def test_start_flow_with_handlers(self):
"""Test start_flow method with multiple handlers"""
"""start_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.start_flow = Mock()
handler1.start_flow = AsyncMock()
handler2 = Mock()
handler2.start_flow = Mock()
handler2.start_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler1.start_flow.assert_called_once_with("flow1", flow_data)
handler2.start_flow.assert_called_once_with("flow1", flow_data)
handler1.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_start_flow_with_handler_exception(self):
"""Test start_flow method handles handler exceptions"""
"""Handler exceptions in start_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.start_flow = Mock(side_effect=Exception("Handler error"))
handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler.start_flow.assert_called_once_with("flow1", flow_data)
handler.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handlers(self):
"""Test stop_flow method with multiple handlers"""
"""stop_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.stop_flow = Mock()
handler1.stop_flow = AsyncMock()
handler2 = Mock()
handler2.stop_flow = Mock()
handler2.stop_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
handler1.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handler_exception(self):
"""Test stop_flow method handles handler exceptions"""
"""Handler exceptions in stop_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler.stop_flow.assert_called_once_with("flow1", flow_data)
handler.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@patch('asyncio.create_task')
@pytest.mark.asyncio
@ -329,25 +323,25 @@ class TestConfigReceiver:
mock_create_task.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_and_apply_mixed_flow_operations(self):
"""Test fetch_and_apply with mixed add/remove operations"""
async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
"""fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config removes flow1, keeps flow2, adds flow3
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow2": '{"name": "test_flow_2"}',
"flow3": '{"name": "test_flow_3"}'
"flow3": '{"name": "test_flow_3"}',
}
}
@ -358,20 +352,22 @@ class TestConfigReceiver:
start_calls = []
stop_calls = []
async def mock_start_flow(id, flow):
start_calls.append((id, flow))
async def mock_stop_flow(id, flow):
stop_calls.append((id, flow))
async def mock_start_flow(workspace, id, flow):
start_calls.append((workspace, id, flow))
async def mock_stop_flow(workspace, id, flow):
stop_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" not in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow3" in config_receiver.flows
ws_flows = config_receiver.flows["default"]
assert "flow1" not in ws_flows
assert "flow2" in ws_flows
assert "flow3" in ws_flows
assert len(start_calls) == 1
assert start_calls[0][0] == "flow3"
assert start_calls[0][:2] == ("default", "flow3")
assert len(stop_calls) == 1
assert stop_calls[0][0] == "flow1"
assert stop_calls[0][:2] == ("default", "flow1")

View file

@ -0,0 +1,406 @@
"""
Round-trip unit tests for the core msgpack import/export gateway endpoints.
The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
the responder callback and packs them into msgpack tuples. The kg-core
import endpoint takes msgpack tuples back off the wire and rebuilds
KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
(whose translator decodes them into real dataclasses).
Regression coverage: the previous wire format used `"vectors"` (plural)
in the entity blobs and embedded a stale `"m"` field that referenced the
removed `Metadata.metadata` triples-list field. The export side hit a
KeyError on first message; the import side built dicts that the
KnowledgeRequestTranslator (separately fixed) couldn't decode. These
tests pin both halves of the wire protocol.
"""
import msgpack
import pytest
from unittest.mock import AsyncMock, Mock, patch
from trustgraph.gateway.dispatch.core_export import CoreExport
from trustgraph.gateway.dispatch.core_import import CoreImport
# ---------------------------------------------------------------------------
# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
# would emit). The vector wire key is *singular* on purpose; the export
# side previously read the wrong key and crashed.
# ---------------------------------------------------------------------------
def _ge_response_dict():
return {
"graph-embeddings": {
"metadata": {
"id": "doc-1",
"root": "",
"collection": "testcoll",
},
"entities": [
{
"entity": {"t": "i", "i": "http://example.org/alice"},
"vector": [0.1, 0.2, 0.3],
},
{
"entity": {"t": "i", "i": "http://example.org/bob"},
"vector": [0.4, 0.5, 0.6],
},
],
}
}
def _triples_response_dict():
return {
"triples": {
"metadata": {
"id": "doc-1",
"root": "",
"collection": "testcoll",
},
"triples": [
{
"s": {"t": "i", "i": "http://example.org/alice"},
"p": {"t": "i", "i": "http://example.org/knows"},
"o": {"t": "i", "i": "http://example.org/bob"},
},
],
}
}
def _make_request(id_="doc-1", workspace="alice"):
request = Mock()
request.query = {"id": id_, "workspace": workspace}
return request
def _make_data_reader(payload: bytes):
"""Mock the aiohttp StreamReader: returns payload once, then EOF."""
chunks = [payload, b""]
data = Mock()
async def fake_read(n):
return chunks.pop(0) if chunks else b""
data.read = fake_read
return data
# ---------------------------------------------------------------------------
# Export side: translator-shaped dict -> msgpack bytes
# ---------------------------------------------------------------------------
class TestCoreExportWireFormat:
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_export_packs_graph_embeddings_with_singular_vector(
self, mock_kr_class,
):
"""The export side must read `ent["vector"]` and emit `v`. The
previous bug was reading `ent["vectors"]` which KeyErrored against
the translator output."""
captured = []
async def fake_kr_process(req_dict, responder):
await responder(_ge_response_dict(), True)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
response = AsyncMock()
async def fake_write(b):
captured.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
ok = AsyncMock(return_value=response)
error = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(),
error=error,
ok=ok,
request=_make_request(),
)
# Did not raise, did not call error()
error.assert_not_called()
assert len(captured) == 1
unpacker = msgpack.Unpacker()
unpacker.feed(captured[0])
items = list(unpacker)
assert len(items) == 1
msg_type, payload = items[0]
assert msg_type == "ge"
# Metadata envelope: only id/collection — no stale `m["m"]`.
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
# Entities: each carries the *singular* `v` and the term envelope
assert len(payload["e"]) == 2
assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
assert payload["e"][0]["e"]["i"] == "http://example.org/alice"
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_export_packs_triples(self, mock_kr_class):
captured = []
async def fake_kr_process(req_dict, responder):
await responder(_triples_response_dict(), True)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
response = AsyncMock()
async def fake_write(b):
captured.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
ok = AsyncMock(return_value=response)
error = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(), error=error, ok=ok, request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
unpacker = msgpack.Unpacker()
unpacker.feed(captured[0])
items = list(unpacker)
assert len(items) == 1
msg_type, payload = items[0]
assert msg_type == "t"
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
assert len(payload["t"]) == 1
# ---------------------------------------------------------------------------
# Import side: msgpack bytes -> translator-shaped dict
# ---------------------------------------------------------------------------
class TestCoreImportWireFormat:
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
async def test_import_unpacks_graph_embeddings_to_singular_vector(
self, mock_kr_class,
):
"""The import side must build dicts whose entity blobs have the
singular `vector` key that's what the KnowledgeRequestTranslator
decode side reads. Previous bug emitted `vectors`."""
captured = []
async def fake_kr_process(req_dict):
captured.append(req_dict)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
# Build a msgpack tuple matching the new wire format
payload = msgpack.packb((
"ge",
{
"m": {"i": "doc-1", "c": "testcoll"},
"e": [
{
"e": {"t": "i", "i": "http://example.org/alice"},
"v": [0.1, 0.2, 0.3],
},
],
},
))
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
error = AsyncMock()
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(payload),
error=error,
ok=ok,
request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
req = captured[0]
assert req["operation"] == "put-kg-core"
assert req["workspace"] == "alice"
assert req["id"] == "doc-1"
ge = req["graph-embeddings"]
# Metadata envelope must NOT contain a stale `metadata` key
# referencing the removed Metadata.metadata field.
assert "metadata" not in ge["metadata"]
assert ge["metadata"] == {
"id": "doc-1",
"collection": "default",
}
# Entity blob carries the singular `vector` key
assert len(ge["entities"]) == 1
ent = ge["entities"][0]
assert ent["vector"] == [0.1, 0.2, 0.3]
assert "vectors" not in ent
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
async def test_import_unpacks_triples(self, mock_kr_class):
captured = []
async def fake_kr_process(req_dict):
captured.append(req_dict)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
payload = msgpack.packb((
"t",
{
"m": {"i": "doc-1", "c": "testcoll"},
"t": [
{
"s": {"t": "i", "i": "http://example.org/alice"},
"p": {"t": "i", "i": "http://example.org/knows"},
"o": {"t": "i", "i": "http://example.org/bob"},
},
],
},
))
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
error = AsyncMock()
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(payload),
error=error,
ok=ok,
request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
req = captured[0]
triples = req["triples"]
assert "metadata" not in triples["metadata"] # no stale field
assert len(triples["triples"]) == 1
# ---------------------------------------------------------------------------
# Full round-trip: export bytes feed directly into import
# ---------------------------------------------------------------------------
class TestCoreImportExportRoundTrip:
"""End-to-end: produce bytes via core_export, consume them via
core_import, and verify the dict that lands at the import-side
translator is structurally equivalent to what went in. This is the
test that catches asymmetries between the two halves."""
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_graph_embeddings_round_trip(
self, mock_export_kr_class, mock_import_kr_class,
):
# ----- export side: capture bytes -----
export_bytes = []
async def fake_export_process(req_dict, responder):
await responder(_ge_response_dict(), True)
export_kr = AsyncMock()
export_kr.start = AsyncMock()
export_kr.stop = AsyncMock()
export_kr.process = fake_export_process
mock_export_kr_class.return_value = export_kr
response = AsyncMock()
async def fake_write(b):
export_bytes.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(),
error=AsyncMock(),
ok=AsyncMock(return_value=response),
request=_make_request(),
)
assert len(export_bytes) == 1
# ----- import side: feed those bytes back in -----
import_captured = []
async def fake_import_process(req_dict):
import_captured.append(req_dict)
import_kr = AsyncMock()
import_kr.start = AsyncMock()
import_kr.stop = AsyncMock()
import_kr.process = fake_import_process
mock_import_kr_class.return_value = import_kr
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(export_bytes[0]),
error=AsyncMock(),
ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
request=_make_request(),
)
# ----- verify the dict the importer would hand to the translator -----
assert len(import_captured) == 1
req = import_captured[0]
original = _ge_response_dict()["graph-embeddings"]
ge = req["graph-embeddings"]
# The import side overrides id from the URL query (intentional),
# so we only round-trip the entity payload itself.
assert ge["metadata"]["id"] == original["metadata"]["id"]
assert len(ge["entities"]) == len(original["entities"])
for got, want in zip(ge["entities"], original["entities"]):
assert got["vector"] == want["vector"]
assert got["entity"] == want["entity"]

View file

@ -72,10 +72,10 @@ class TestDispatcherManager:
flow_data = {"name": "test_flow", "steps": []}
await manager.start_flow("flow1", flow_data)
assert "flow1" in manager.flows
assert manager.flows["flow1"] == flow_data
await manager.start_flow("default", "flow1", flow_data)
assert ("default", "flow1") in manager.flows
assert manager.flows[("default", "flow1")] == flow_data
@pytest.mark.asyncio
async def test_stop_flow(self):
@ -86,11 +86,11 @@ class TestDispatcherManager:
# Pre-populate with a flow
flow_data = {"name": "test_flow", "steps": []}
manager.flows["flow1"] = flow_data
await manager.stop_flow("flow1", flow_data)
assert "flow1" not in manager.flows
manager.flows[("default", "flow1")] = flow_data
await manager.stop_flow("default", "flow1", flow_data)
assert ("default", "flow1") not in manager.flows
def test_dispatch_global_service_returns_wrapper(self):
"""Test dispatch_global_service returns DispatcherWrapper"""
@ -275,12 +275,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"queue": "test_queue"}
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -290,7 +290,7 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
params = {"flow": "test_flow", "kind": "triples"}
result = await manager.process_flow_import("ws", "running", params)
@ -298,7 +298,7 @@ class TestDispatcherManager:
backend=mock_backend,
ws="ws",
running="running",
queue={"queue": "test_queue"}
queue="test_queue"
)
mock_dispatcher.start.assert_called_once()
assert result == mock_dispatcher
@ -326,12 +326,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"queue": "test_queue"}
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
mock_dispatchers.__contains__.return_value = False
@ -348,12 +348,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"queue": "test_queue"}
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -370,7 +370,7 @@ class TestDispatcherManager:
backend=mock_backend,
ws="ws",
running="running",
queue={"queue": "test_queue"},
queue="test_queue",
consumer="api-gateway-test-uuid",
subscriber="api-gateway-test-uuid"
)
@ -404,7 +404,7 @@ class TestDispatcherManager:
params = {"flow": "test_flow", "kind": "agent"}
result = await manager.process_flow_service("data", "responder", params)
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
assert result == "flow_result"
@pytest.mark.asyncio
@ -415,14 +415,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Add flow to the flows dictionary
manager.flows["test_flow"] = {"services": {"agent": {}}}
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
# Pre-populate with existing dispatcher
mock_dispatcher = Mock()
mock_dispatcher.process = AsyncMock(return_value="cached_result")
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
mock_dispatcher.process.assert_called_once_with("data", "responder")
assert result == "cached_result"
@ -435,7 +435,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -443,7 +443,7 @@ class TestDispatcherManager:
}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
@ -452,23 +452,23 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
request_queue="agent_request_queue",
response_queue="agent_response_queue",
timeout=120,
consumer="api-gateway-test_flow-agent-request",
subscriber="api-gateway-test_flow-agent-request"
consumer="api-gateway-default-test_flow-agent-request",
subscriber="api-gateway-default-test_flow-agent-request"
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
assert result == "new_result"
@pytest.mark.asyncio
@ -479,36 +479,36 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-load": {"queue": "text_load_queue"}
"text-load": {"flow": "text_load_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
mock_rr_dispatchers.__contains__.return_value = False
mock_sender_dispatchers.__contains__.return_value = True
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
mock_dispatcher.start = AsyncMock()
mock_dispatcher.process = AsyncMock(return_value="sender_result")
mock_dispatcher_class.return_value = mock_dispatcher
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
queue={"queue": "text_load_queue"}
queue="text_load_queue"
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
assert result == "sender_result"
@pytest.mark.asyncio
@ -519,7 +519,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
with pytest.raises(RuntimeError, match="Invalid flow"):
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
@ -529,14 +529,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow without agent interface
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-completion": {"request": "req", "response": "resp"}
}
}
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_invalid_kind(self):
@ -546,7 +546,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow with interface but unsupported kind
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"invalid-kind": {"request": "req", "response": "resp"}
}
@ -558,7 +558,7 @@ class TestDispatcherManager:
mock_sender_dispatchers.__contains__.return_value = False
with pytest.raises(RuntimeError, match="Invalid kind"):
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
@pytest.mark.asyncio
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
@ -608,7 +608,7 @@ class TestDispatcherManager:
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -630,7 +630,7 @@ class TestDispatcherManager:
mock_rr_dispatchers.__contains__.return_value = True
results = await asyncio.gather(*[
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
for _ in range(5)
])
@ -638,5 +638,5 @@ class TestDispatcherManager:
"Dispatcher class instantiated more than once — duplicate consumer bug"
)
assert mock_dispatcher.start.call_count == 1
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
assert all(r == "result" for r in results)

View file

@ -0,0 +1,75 @@
"""Tests for Gateway i18n pack endpoint."""
import json
from unittest.mock import MagicMock
import pytest
from aiohttp import web
from trustgraph.gateway.endpoint.i18n import I18nPackEndpoint
class TestI18nPackEndpoint:
def test_i18n_endpoint_initialization(self):
mock_auth = MagicMock()
endpoint = I18nPackEndpoint(
endpoint_path="/api/v1/i18n/packs/{lang}",
auth=mock_auth,
)
assert endpoint.path == "/api/v1/i18n/packs/{lang}"
assert endpoint.auth == mock_auth
assert endpoint.operation == "service"
@pytest.mark.asyncio
async def test_i18n_endpoint_start_method(self):
mock_auth = MagicMock()
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
await endpoint.start()
def test_add_routes_registers_get_handler(self):
mock_auth = MagicMock()
mock_app = MagicMock()
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
endpoint.add_routes(mock_app)
mock_app.add_routes.assert_called_once()
call_args = mock_app.add_routes.call_args[0][0]
assert len(call_args) == 1
@pytest.mark.asyncio
async def test_handle_unauthorized_on_invalid_auth_scheme(self):
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
request = MagicMock()
request.path = "/api/v1/i18n/packs/en"
request.headers = {"Authorization": "Token abc"}
request.match_info = {"lang": "en"}
resp = await endpoint.handle(request)
assert isinstance(resp, web.HTTPUnauthorized)
@pytest.mark.asyncio
async def test_handle_returns_pack_when_permitted(self):
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
request = MagicMock()
request.path = "/api/v1/i18n/packs/en"
request.headers = {}
request.match_info = {"lang": "en"}
resp = await endpoint.handle(request)
assert resp.status == 200
payload = json.loads(resp.body.decode("utf-8"))
assert isinstance(payload, dict)
assert "cli.verify_system_status.title" in payload

View file

@ -0,0 +1,241 @@
"""
Unit tests for entity contexts import dispatcher.
Tests the business logic of EntityContextsImport while mocking the
Publisher and websocket components.
Regression coverage: a previous version constructed Metadata(metadata=...)
which raised TypeError at runtime as soon as a message was received. These
tests exercise receive() end-to-end so any future schema/kwarg drift in
the Metadata or EntityContexts construction is caught immediately.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport
from trustgraph.schema import EntityContexts, EntityContext, Metadata
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def mock_running():
running = Mock()
running.get.return_value = True
running.stop = Mock()
return running
@pytest.fixture
def mock_websocket():
ws = Mock()
ws.close = AsyncMock()
return ws
@pytest.fixture
def sample_message():
"""Sample entity-contexts websocket message."""
return {
"metadata": {
"id": "doc-123",
"user": "testuser",
"collection": "testcollection",
},
"entities": [
{
"entity": {"v": "http://example.org/alice", "e": True},
"context": "Alice is a person.",
},
{
"entity": {"v": "http://example.org/bob", "e": True},
"context": "Bob is a person.",
},
],
}
@pytest.fixture
def empty_entities_message():
return {
"metadata": {
"id": "doc-empty",
"user": "u",
"collection": "c",
},
"entities": [],
}
class TestEntityContextsImportInitialization:
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
def test_init_creates_publisher_with_correct_params(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="ec-queue",
)
mock_publisher_class.assert_called_once_with(
mock_backend,
topic="ec-queue",
schema=EntityContexts,
)
assert dispatcher.ws is mock_websocket
assert dispatcher.running is mock_running
assert dispatcher.publisher is instance
class TestEntityContextsImportLifecycle:
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.start = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.start()
instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(
self, mock_publisher_class, mock_backend, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=None, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
class TestEntityContextsImportMessageProcessing:
"""Regression coverage for receive(): catches Metadata/schema drift."""
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_constructs_entity_contexts_correctly(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
# If Metadata or EntityContexts gain/lose kwargs, this raises
# TypeError — exactly the regression we want to catch.
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
call_args = instance.send.call_args
assert call_args[0][0] is None
sent = call_args[0][1]
assert isinstance(sent, EntityContexts)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2
assert all(isinstance(e, EntityContext) for e in sent.entities)
assert sent.entities[0].context == "Alice is a person."
assert sent.entities[1].context == "Bob is a person."
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_entities(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, empty_entities_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = empty_entities_message
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
sent = instance.send.call_args[0][1]
assert isinstance(sent, EntityContexts)
assert sent.entities == []
assert sent.metadata.id == "doc-empty"
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
with pytest.raises(RuntimeError, match="publish failed"):
await dispatcher.receive(mock_msg)

View file

@ -158,7 +158,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="explain",
message_type="explain",
content="",
explain_id="urn:trustgraph:agent:session:abc123",
explain_graph="urn:graph:retrieval",
@ -179,7 +179,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="thought",
message_type="thought",
content="I need to think...",
)
@ -190,7 +190,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="explain",
message_type="explain",
explain_id="urn:trustgraph:agent:session:abc123",
explain_triples=sample_triples(),
end_of_dialog=False,
@ -203,7 +203,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="answer",
message_type="answer",
content="The answer is...",
end_of_dialog=True,
)

View file

@ -0,0 +1,246 @@
"""
Unit tests for graph embeddings import dispatcher.
Tests the business logic of GraphEmbeddingsImport while mocking the
Publisher and websocket components.
Regression coverage: a previous version of EntityContextsImport
constructed Metadata(metadata=...) which raised TypeError at runtime as
soon as a message was received. The same shape of bug can occur here, so
these tests exercise receive() end-to-end to catch any future schema or
kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport
from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def mock_running():
running = Mock()
running.get.return_value = True
running.stop = Mock()
return running
@pytest.fixture
def mock_websocket():
ws = Mock()
ws.close = AsyncMock()
return ws
@pytest.fixture
def sample_message():
"""Sample graph-embeddings websocket message."""
return {
"metadata": {
"id": "doc-123",
"user": "testuser",
"collection": "testcollection",
},
"entities": [
{
"entity": {"v": "http://example.org/alice", "e": True},
"vector": [0.1, 0.2, 0.3],
},
{
"entity": {"v": "http://example.org/bob", "e": True},
"vector": [0.4, 0.5, 0.6],
},
],
}
@pytest.fixture
def empty_entities_message():
return {
"metadata": {
"id": "doc-empty",
"user": "u",
"collection": "c",
},
"entities": [],
}
class TestGraphEmbeddingsImportInitialization:
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
def test_init_creates_publisher_with_correct_params(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="ge-queue",
)
mock_publisher_class.assert_called_once_with(
mock_backend,
topic="ge-queue",
schema=GraphEmbeddings,
)
assert dispatcher.ws is mock_websocket
assert dispatcher.running is mock_running
assert dispatcher.publisher is instance
class TestGraphEmbeddingsImportLifecycle:
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.start = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.start()
instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(
self, mock_publisher_class, mock_backend, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=None, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
class TestGraphEmbeddingsImportMessageProcessing:
"""Regression coverage for receive(): catches Metadata/schema drift."""
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_constructs_graph_embeddings_correctly(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
# If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose
# kwargs, this raises TypeError — exactly the regression we want
# to catch.
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
call_args = instance.send.call_args
assert call_args[0][0] is None
sent = call_args[0][1]
assert isinstance(sent, GraphEmbeddings)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2
assert all(isinstance(e, EntityEmbeddings) for e in sent.entities)
# Lock in the wire format: incoming "vector" key (singular,
# list[float]) maps to EntityEmbeddings.vector. This mirrors
# serialize_graph_embeddings() on the export side.
assert sent.entities[0].vector == [0.1, 0.2, 0.3]
assert sent.entities[1].vector == [0.4, 0.5, 0.6]
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_entities(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, empty_entities_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = empty_entities_message
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
sent = instance.send.call_args[0][1]
assert isinstance(sent, GraphEmbeddings)
assert sent.entities == []
assert sent.metadata.id == "doc-empty"
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
with pytest.raises(RuntimeError, match="publish failed"):
await dispatcher.receive(mock_msg)

View file

@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing:
# Check metadata
assert sent_object.metadata.id == "obj-123"
assert sent_object.metadata.user == "testuser"
assert sent_object.metadata.collection == "testcollection"
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')

View file

@ -171,6 +171,14 @@ class TestApi:
patch('aiohttp.web.run_app') as mock_run_app:
mock_get_pubsub.return_value = Mock()
# Api.run() passes self.app_factory() — a coroutine — to
# web.run_app, which would normally consume it inside its own
# event loop. Since we mock run_app, close the coroutine here
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
def _consume_coro(coro, **kwargs):
coro.close()
mock_run_app.side_effect = _consume_coro
api = Api(port=8080)
api.run()

View file

@ -23,7 +23,6 @@ class TestTextDocumentTranslator:
)
assert msg.metadata.id == "doc-1"
assert msg.metadata.user == "alice"
assert msg.metadata.collection == "research"
assert msg.text == payload.encode("utf-8")

View file

@ -29,10 +29,9 @@ class Triple:
self.o = o
class Metadata:
def __init__(self, id, user, collection, root=""):
def __init__(self, id, collection, root=""):
self.id = id
self.root = root
self.user = user
self.collection = collection
class Triples:
@ -108,7 +107,6 @@ def sample_triples(sample_triple):
"""Sample Triples batch object"""
metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -123,7 +121,6 @@ def sample_chunk():
"""Sample text chunk for processing"""
metadata = Metadata(
id="test-chunk-456",
user="test_user",
collection="test_collection",
)

View file

@ -322,7 +322,6 @@ This is not JSON at all
assert isinstance(sent_triples, Triples)
# Check metadata fields individually since implementation creates new Metadata object
assert sent_triples.metadata.id == sample_metadata.id
assert sent_triples.metadata.user == sample_metadata.user
assert sent_triples.metadata.collection == sample_metadata.collection
assert len(sent_triples.triples) == 1
assert sent_triples.triples[0].s.iri == "test:subject"
@ -346,7 +345,6 @@ This is not JSON at all
assert isinstance(sent_contexts, EntityContexts)
# Check metadata fields individually since implementation creates new Metadata object
assert sent_contexts.metadata.id == sample_metadata.id
assert sent_contexts.metadata.user == sample_metadata.user
assert sent_contexts.metadata.collection == sample_metadata.collection
assert len(sent_contexts.entities) == 1
assert sent_contexts.entities[0].entity.iri == "test:entity"

View file

@ -311,8 +311,7 @@ class TestObjectExtractionBusinessLogic:
"""Test ExtractedObject creation and properties"""
# Arrange
metadata = Metadata(
id="test-extraction-001",
user="test_user",
id="test-extraction-001",
collection="test_collection",
)
@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic:
assert extracted_obj.values[0]["customer_id"] == "CUST001"
assert extracted_obj.confidence == 0.95
assert "John Doe" in extracted_obj.source_span
assert extracted_obj.metadata.user == "test_user"
def test_config_parsing_error_handling(self):
"""Test configuration parsing with invalid JSON"""

View file

@ -371,7 +371,6 @@ class TestTripleConstructionLogic:
metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -384,7 +383,6 @@ class TestTripleConstructionLogic:
# Assert
assert isinstance(triples_batch, Triples)
assert triples_batch.metadata.id == "test-doc-123"
assert triples_batch.metadata.user == "test_user"
assert triples_batch.metadata.collection == "test_collection"
assert len(triples_batch.triples) == 2

View file

@ -0,0 +1,119 @@
import asyncio
import io
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from uuid import uuid4
from minio.error import S3Error
from trustgraph.librarian.blob_store import BlobStore
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_blob_store():
"""Create a BlobStore with mocked Minio client."""
mock_minio = MagicMock()
with patch('trustgraph.librarian.blob_store.Minio', return_value=mock_minio):
# Prevent ensure_bucket from making network calls during init
with patch('trustgraph.librarian.blob_store.BlobStore.ensure_bucket'):
store = BlobStore(
endpoint="localhost:9000",
access_key="access",
secret_key="secret",
bucket_name="test-bucket"
)
return store, mock_minio
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_add_success_no_retry():
store, mock_minio = _make_blob_store()
object_id = uuid4()
await store.add(object_id, b"data", "text/plain")
mock_minio.put_object.assert_called_once()
@pytest.mark.asyncio
async def test_retry_recovery_on_transient_failure():
store, mock_minio = _make_blob_store()
store.base_delay = 0 # Disable delay for fast tests
# Fail twice, succeed third time
mock_minio.put_object.side_effect = [
Exception("Error 1"),
Exception("Error 2"),
MagicMock()
]
await store.add(uuid4(), b"data", "text/plain")
assert mock_minio.put_object.call_count == 3
@pytest.mark.asyncio
async def test_retry_exhaustion_after_8_attempts():
store, mock_minio = _make_blob_store()
store.base_delay = 0
# Permanent failure
mock_minio.put_object.side_effect = Exception("Permanent failure")
with pytest.raises(Exception, match="Permanent failure"):
await store.add(uuid4(), b"data", "text/plain")
# Author requirement: exactly 8 attempts
assert mock_minio.put_object.call_count == 8
@pytest.mark.asyncio
async def test_s3_error_triggers_retry():
store, mock_minio = _make_blob_store()
store.base_delay = 0
# Mock S3Error
s3_err = S3Error("code", "msg", "res", "req", "host", None)
mock_minio.get_object.side_effect = [s3_err, MagicMock()]
await store.get(uuid4())
assert mock_minio.get_object.call_count == 2
@pytest.mark.asyncio
async def test_exponential_backoff_delays():
store, mock_minio = _make_blob_store()
# Use real base_delay to check math
store.base_delay = 0.25
# Correct method name is stat_object, not get_size
mock_minio.stat_object = MagicMock(side_effect=Exception("Wait"))
with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep:
with pytest.raises(Exception):
await store.get_size(uuid4())
# Should have 7 sleep calls for 8 attempts
assert mock_sleep.call_count == 7
# Check actual sleep durations: 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0
sleep_args = [call[0][0] for call in mock_sleep.call_args_list]
assert sleep_args == [0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
@pytest.mark.asyncio
async def test_runs_in_executor():
"""Verify that synchronous Minio calls are offloaded to an executor."""
store, mock_minio = _make_blob_store()
# Mock response object with .read() method
mock_response = MagicMock()
mock_response.read.return_value = b"result"
with patch('asyncio.get_event_loop') as mock_loop:
mock_loop_instance = MagicMock()
mock_loop.return_value = mock_loop_instance
mock_loop_instance.run_in_executor = AsyncMock(return_value=mock_response)
await store.get(uuid4())
mock_loop_instance.run_in_executor.assert_called_once()

View file

@ -22,6 +22,10 @@ def _make_librarian(min_chunk_size=1):
"""Create a Librarian with mocked blob_store and table_store."""
lib = Librarian.__new__(Librarian)
lib.blob_store = MagicMock()
lib.blob_store.create_multipart_upload = AsyncMock()
lib.blob_store.upload_part = AsyncMock()
lib.blob_store.complete_multipart_upload = AsyncMock()
lib.blob_store.abort_multipart_upload = AsyncMock()
lib.table_store = AsyncMock()
lib.load_document = AsyncMock()
lib.min_chunk_size = min_chunk_size
@ -29,12 +33,12 @@ def _make_librarian(min_chunk_size=1):
def _make_doc_metadata(
doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc"
):
meta = MagicMock()
meta.id = doc_id
meta.kind = kind
meta.user = user
meta.workspace = workspace
meta.title = title
meta.time = 1700000000
meta.comments = ""
@ -43,27 +47,27 @@ def _make_doc_metadata(
def _make_begin_request(
doc_id="doc-1", kind="application/pdf", user="alice",
doc_id="doc-1", kind="application/pdf", workspace="alice",
total_size=10_000_000, chunk_size=0
):
req = MagicMock()
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user)
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace)
req.total_size = total_size
req.chunk_size = chunk_size
return req
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"):
req = MagicMock()
req.upload_id = upload_id
req.chunk_index = chunk_index
req.user = user
req.workspace = workspace
req.content = base64.b64encode(content)
return req
def _make_session(
user="alice", total_chunks=5, chunk_size=2_000_000,
workspace="alice", total_chunks=5, chunk_size=2_000_000,
total_size=10_000_000, chunks_received=None, object_id="obj-1",
s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
):
@ -72,11 +76,11 @@ def _make_session(
if document_metadata is None:
document_metadata = json.dumps({
"id": document_id, "kind": "application/pdf",
"user": user, "title": "Test", "time": 1700000000,
"workspace": workspace, "title": "Test", "time": 1700000000,
"comments": "", "tags": [],
})
return {
"user": user,
"workspace": workspace,
"total_chunks": total_chunks,
"chunk_size": chunk_size,
"total_size": total_size,
@ -255,10 +259,10 @@ class TestUploadChunk:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(user="bob")
req = _make_upload_chunk_request(workspace="bob")
with pytest.raises(RequestError, match="Not authorized"):
await lib.upload_chunk(req)
@ -349,7 +353,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.complete_upload(req)
@ -371,7 +375,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
await lib.complete_upload(req)
@ -390,7 +394,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
with pytest.raises(RequestError, match="Missing chunks"):
await lib.complete_upload(req)
@ -402,7 +406,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
req.workspace = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.complete_upload(req)
@ -410,12 +414,12 @@ class TestCompleteUpload:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.complete_upload(req)
@ -435,7 +439,7 @@ class TestAbortUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.abort_upload(req)
@ -452,7 +456,7 @@ class TestAbortUpload:
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
req.workspace = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.abort_upload(req)
@ -460,12 +464,12 @@ class TestAbortUpload:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.abort_upload(req)
@ -488,7 +492,7 @@ class TestGetUploadStatus:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.get_upload_status(req)
@ -506,7 +510,7 @@ class TestGetUploadStatus:
req = MagicMock()
req.upload_id = "up-expired"
req.user = "alice"
req.workspace = "alice"
resp = await lib.get_upload_status(req)
@ -523,7 +527,7 @@ class TestGetUploadStatus:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.get_upload_status(req)
@ -535,12 +539,12 @@ class TestGetUploadStatus:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.get_upload_status(req)
@ -560,7 +564,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
@ -583,7 +587,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
@ -604,7 +608,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
@ -626,7 +630,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x")
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 0 # Should use default 1MB
@ -645,7 +649,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=raw)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 1000
@ -662,7 +666,7 @@ class TestStreamDocument:
lib.blob_store.get_size = AsyncMock(return_value=5000)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 512
@ -694,7 +698,7 @@ class TestListUploads:
]
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
resp = await lib.list_uploads(req)
@ -709,7 +713,7 @@ class TestListUploads:
lib.table_store.list_upload_sessions.return_value = []
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
resp = await lib.list_uploads(req)

View file

@ -0,0 +1,590 @@
"""
DAG structure tests for provenance chains.
Verifies that the wasDerivedFrom chain has the expected shape for each
service. These tests catch structural regressions when new entities are
inserted into the chain (e.g. PatternDecision between session and first
iteration).
Expected chains:
GraphRAG: question grounding exploration focus synthesis
DocumentRAG: question grounding exploration synthesis
Agent React: session pattern-decision iteration (observation iteration)* final
Agent Plan: session pattern-decision plan step-result(s) synthesis
Agent Super: session pattern-decision decomposition (fan-out) finding(s) synthesis
"""
import json
import uuid
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
Triple, Term, IRI, LITERAL,
)
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL,
TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION,
TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION,
TG_OBSERVATION_TYPE,
TG_PATTERN, TG_TASK_TYPE,
)
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _collect_events(events):
"""Build a dict of explain_id → {types, derived_from, triples}."""
result = {}
for ev in events:
eid = ev["explain_id"]
triples = ev["triples"]
types = {
t.o.iri for t in triples
if t.s.iri == eid and t.p.iri == RDF_TYPE
}
parents = [
t.o.iri for t in triples
if t.s.iri == eid and t.p.iri == PROV_WAS_DERIVED_FROM
]
result[eid] = {
"types": types,
"derived_from": parents[0] if parents else None,
"triples": triples,
}
return result
def _find_by_type(dag, rdf_type):
"""Find all event IDs that have the given rdf:type."""
return [eid for eid, info in dag.items() if rdf_type in info["types"]]
def _assert_chain(dag, chain_types):
"""Assert that a linear wasDerivedFrom chain exists through the given types."""
for i in range(1, len(chain_types)):
parent_type = chain_types[i - 1]
child_type = chain_types[i]
parents = _find_by_type(dag, parent_type)
children = _find_by_type(dag, child_type)
assert parents, f"No entity with type {parent_type}"
assert children, f"No entity with type {child_type}"
# At least one child must derive from at least one parent
linked = False
for child_id in children:
derived = dag[child_id]["derived_from"]
if derived in parents:
linked = True
break
assert linked, (
f"No {child_type} derives from {parent_type}. "
f"Children derive from: "
f"{[dag[c]['derived_from'] for c in children]}"
)
# ---------------------------------------------------------------------------
# GraphRAG DAG structure
# ---------------------------------------------------------------------------
class TestGraphRagDagStructure:
"""Verify: question → grounding → exploration → focus → synthesis"""
@pytest.fixture
def mock_clients(self):
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
graph_embeddings_client = AsyncMock()
triples_client = AsyncMock()
embeddings_client.embed.return_value = [[0.1, 0.2]]
graph_embeddings_client.query.return_value = [
MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")),
]
triples_client.query_stream.return_value = [
Triple(
s=Term(type=IRI, iri="http://example.com/e1"),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="value"),
)
]
triples_client.query.return_value = []
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="concept")
elif template_id == "kg-edge-scoring":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "score": 10} for e in edges],
)
elif template_id == "kg-edge-reasoning":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
)
elif template_id == "kg-synthesis":
return PromptResult(response_type="text", text="Answer.")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
@pytest.mark.asyncio
async def test_dag_chain(self, mock_clients):
rag = GraphRag(*mock_clients)
events = []
async def explain_cb(triples, explain_id):
events.append({"explain_id": explain_id, "triples": triples})
await rag.query(
query="test", explain_callback=explain_cb, edge_score_limit=0,
)
dag = _collect_events(events)
assert len(dag) == 5, f"Expected 5 events, got {len(dag)}"
_assert_chain(dag, [
TG_GRAPH_RAG_QUESTION,
TG_GROUNDING,
TG_EXPLORATION,
TG_FOCUS,
TG_SYNTHESIS,
])
# ---------------------------------------------------------------------------
# DocumentRAG DAG structure
# ---------------------------------------------------------------------------
class TestDocumentRagDagStructure:
"""Verify: question → grounding → exploration → synthesis"""
@pytest.fixture
def mock_clients(self):
from trustgraph.schema import ChunkMatch
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
doc_embeddings_client = AsyncMock()
fetch_chunk = AsyncMock(return_value="Chunk content.")
embeddings_client.embed.return_value = [[0.1, 0.2]]
doc_embeddings_client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.9),
]
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="concept")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
prompt_client.document_prompt.return_value = PromptResult(
response_type="text", text="Answer.",
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
@pytest.mark.asyncio
async def test_dag_chain(self, mock_clients):
rag = DocumentRag(*mock_clients)
events = []
async def explain_cb(triples, explain_id):
events.append({"explain_id": explain_id, "triples": triples})
await rag.query(
query="test", explain_callback=explain_cb,
)
dag = _collect_events(events)
assert len(dag) == 4, f"Expected 4 events, got {len(dag)}"
_assert_chain(dag, [
TG_DOC_RAG_QUESTION,
TG_GROUNDING,
TG_EXPLORATION,
TG_SYNTHESIS,
])
# ---------------------------------------------------------------------------
# Agent DAG structure — tested via service.agent_request()
# ---------------------------------------------------------------------------
def _make_processor(tools=None):
processor = MagicMock()
processor.max_iterations = 10
processor.save_answer_content = AsyncMock()
def mock_session_uri(sid):
return f"urn:trustgraph:agent:session:{sid}"
processor.provenance_session_uri.side_effect = mock_session_uri
agent = MagicMock()
agent.tools = tools or {}
agent.additional_context = ""
processor.agents = {"default": agent}
processor.aggregator = MagicMock()
return processor
def _make_flow():
producers = {}
def factory(name):
if name not in producers:
producers[name] = AsyncMock()
return producers[name]
flow = MagicMock(side_effect=factory)
flow.workspace = "default"
return flow
def _collect_agent_events(respond_mock):
events = []
for call in respond_mock.call_args_list:
resp = call[0][0]
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
events.append({
"explain_id": resp.explain_id,
"triples": resp.explain_triples,
})
return events
class TestAgentReactDagStructure:
"""
Via service.agent_request(), full two-iteration react chain:
session pattern-decision iteration(1) observation(1) final
Iteration 1: tool call observation
Iteration 2: final answer
"""
def _make_service(self):
from trustgraph.agent.orchestrator.service import Processor
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
mock_tool = MagicMock()
mock_tool.name = "lookup"
mock_tool.description = "Look things up"
mock_tool.arguments = []
mock_tool.groups = []
mock_tool.states = {}
mock_tool_impl = AsyncMock(return_value="42")
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
processor = _make_processor(tools={"lookup": mock_tool})
service = Processor.__new__(Processor)
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
service.plan_pattern = PlanThenExecutePattern(service)
service.supervisor_pattern = SupervisorPattern(service)
service.meta_router = None
return service
@pytest.mark.asyncio
async def test_dag_chain(self):
from trustgraph.agent.react.types import Action, Final
service = self._make_service()
respond = AsyncMock()
next_fn = AsyncMock()
flow = _make_flow()
session_id = str(uuid.uuid4())
# Iteration 1: tool call → returns Action, triggers on_action + tool exec
action = Action(
thought="I need to look this up",
name="lookup",
arguments={"question": "6x7"},
observation="",
)
with patch(
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
) as MockAM:
mock_am = AsyncMock()
MockAM.return_value = mock_am
async def mock_react_iter1(on_action=None, **kwargs):
if on_action:
await on_action(action)
action.observation = "42"
return action
mock_am.react.side_effect = mock_react_iter1
request1 = AgentRequest(
question="What is 6x7?",
collection="default",
streaming=False,
session_id=session_id,
pattern="react",
history=[],
)
await service.agent_request(request1, respond, next_fn, flow)
# next_fn should have been called with updated history
assert next_fn.called
# Iteration 2: final answer
final = Final(thought="The answer is 42", final="42")
next_request = next_fn.call_args[0][0]
with patch(
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
) as MockAM:
mock_am = AsyncMock()
MockAM.return_value = mock_am
async def mock_react_iter2(**kwargs):
return final
mock_am.react.side_effect = mock_react_iter2
await service.agent_request(next_request, respond, next_fn, flow)
# Collect and verify DAG
events = _collect_agent_events(respond)
dag = _collect_events(events)
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
analysis_ids = _find_by_type(dag, TG_ANALYSIS)
observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE)
final_ids = _find_by_type(dag, TG_CONCLUSION)
assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}"
assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}"
assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}"
assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}"
assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}"
# Full chain:
# session → pattern-decision
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
# pattern-decision → iteration(1)
assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0]
# iteration(1) → observation(1)
assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0]
# observation(1) → final
assert dag[final_ids[0]]["derived_from"] == observation_ids[0]
class TestAgentPlanDagStructure:
"""
Via service.agent_request():
session pattern-decision plan step-result synthesis
"""
@pytest.mark.asyncio
async def test_dag_chain(self):
from trustgraph.agent.orchestrator.service import Processor
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
# Mock tool
mock_tool = MagicMock()
mock_tool.name = "knowledge-query"
mock_tool.description = "Query KB"
mock_tool.arguments = []
mock_tool.groups = []
mock_tool.states = {}
mock_tool_impl = AsyncMock(return_value="Found it")
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
processor = _make_processor(tools={"knowledge-query": mock_tool})
service = Processor.__new__(Processor)
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
service.plan_pattern = PlanThenExecutePattern(service)
service.supervisor_pattern = SupervisorPattern(service)
service.meta_router = None
respond = AsyncMock()
next_fn = AsyncMock()
flow = _make_flow()
# Mock prompt client
mock_prompt_client = AsyncMock()
call_count = 0
async def mock_prompt(id, variables=None, **kwargs):
nonlocal call_count
call_count += 1
if id == "plan-create":
return PromptResult(
response_type="jsonl",
objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}],
)
elif id == "plan-step-execute":
return PromptResult(
response_type="json",
object={"tool": "knowledge-query", "arguments": {"question": "test"}},
)
elif id == "plan-synthesise":
return PromptResult(response_type="text", text="Final answer.")
return PromptResult(response_type="text", text="")
mock_prompt_client.prompt.side_effect = mock_prompt
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
session_id = str(uuid.uuid4())
# Iteration 1: planning
request1 = AgentRequest(
question="Test?",
collection="default",
streaming=False,
session_id=session_id,
pattern="plan-then-execute",
history=[],
)
await service.agent_request(request1, respond, next_fn, flow)
# Iteration 2: execute step (next_fn was called with updated request)
assert next_fn.called
next_request = next_fn.call_args[0][0]
# Iteration 3: all steps done → synthesis
# Simulate completed step in history
next_request.history[-1].plan[0].status = "completed"
next_request.history[-1].plan[0].result = "Found it"
await service.agent_request(next_request, respond, next_fn, flow)
events = _collect_agent_events(respond)
dag = _collect_events(events)
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
plan_ids = _find_by_type(dag, TG_PLAN_TYPE)
synthesis_ids = _find_by_type(dag, TG_SYNTHESIS)
assert len(session_ids) == 1
assert len(pd_ids) == 1
assert len(plan_ids) == 1
assert len(synthesis_ids) == 1
# Chain: session → pattern-decision → plan → ... → synthesis
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
class TestAgentSupervisorDagStructure:
"""
Via service.agent_request():
session pattern-decision decomposition (fan-out)
"""
@pytest.mark.asyncio
async def test_dag_chain(self):
from trustgraph.agent.orchestrator.service import Processor
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
processor = _make_processor()
service = Processor.__new__(Processor)
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
service.plan_pattern = PlanThenExecutePattern(service)
service.supervisor_pattern = SupervisorPattern(service)
service.meta_router = None
respond = AsyncMock()
next_fn = AsyncMock()
flow = _make_flow()
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=["Goal A", "Goal B"],
)
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
request = AgentRequest(
question="Research quantum computing",
collection="default",
streaming=False,
session_id=str(uuid.uuid4()),
pattern="supervisor",
history=[],
)
await service.agent_request(request, respond, next_fn, flow)
events = _collect_agent_events(respond)
dag = _collect_events(events)
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)
assert len(session_ids) == 1
assert len(pd_ids) == 1
assert len(decomp_ids) == 1
# Chain: session → pattern-decision → decomposition
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]
# Fan-out should have been called
assert next_fn.call_count == 2 # One per goal

View file

@ -223,7 +223,7 @@ class TestDerivedEntityTriples:
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
def test_chunk_entity_has_chunk_type(self):
def test_chunk_entity_has_message_type(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"chunker", "1.0",

View file

@ -0,0 +1,131 @@
"""
Unit tests for Kafka backend topic parsing and factory dispatch.
Does not require a running Kafka instance.
"""
import pytest
import argparse
from trustgraph.base.kafka_backend import KafkaBackend
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
class TestKafkaParseTopic:
@pytest.fixture
def backend(self):
b = object.__new__(KafkaBackend)
return b
def test_flow_is_durable(self, backend):
name, cls, durable = backend._parse_topic('flow:tg:text-completion-request')
assert durable is True
assert cls == 'flow'
assert name == 'tg.flow.text-completion-request'
def test_notify_is_not_durable(self, backend):
name, cls, durable = backend._parse_topic('notify:tg:config')
assert durable is False
assert cls == 'notify'
assert name == 'tg.notify.config'
def test_request_is_not_durable(self, backend):
name, cls, durable = backend._parse_topic('request:tg:config')
assert durable is False
assert cls == 'request'
assert name == 'tg.request.config'
def test_response_is_not_durable(self, backend):
name, cls, durable = backend._parse_topic('response:tg:librarian')
assert durable is False
assert cls == 'response'
assert name == 'tg.response.librarian'
def test_custom_topicspace(self, backend):
name, cls, durable = backend._parse_topic('flow:prod:my-queue')
assert name == 'prod.flow.my-queue'
assert durable is True
def test_no_colon_defaults_to_flow(self, backend):
name, cls, durable = backend._parse_topic('simple-queue')
assert name == 'tg.flow.simple-queue'
assert cls == 'flow'
assert durable is True
def test_invalid_class_raises(self, backend):
with pytest.raises(ValueError, match="Invalid topic class"):
backend._parse_topic('unknown:tg:topic')
def test_topic_with_flow_suffix(self, backend):
"""Topic names with flow suffix (e.g. :default) have colons replaced with dots."""
name, cls, durable = backend._parse_topic('request:tg:prompt:default')
assert name == 'tg.request.prompt.default'
class TestKafkaRetention:
@pytest.fixture
def backend(self):
b = object.__new__(KafkaBackend)
return b
def test_flow_gets_long_retention(self, backend):
assert backend._retention_ms('flow') == 7 * 24 * 60 * 60 * 1000
def test_request_gets_short_retention(self, backend):
assert backend._retention_ms('request') == 300 * 1000
def test_response_gets_short_retention(self, backend):
assert backend._retention_ms('response') == 300 * 1000
def test_notify_gets_short_retention(self, backend):
assert backend._retention_ms('notify') == 300 * 1000
class TestGetPubsubKafka:
def test_factory_creates_kafka_backend(self):
backend = get_pubsub(pubsub_backend='kafka')
assert isinstance(backend, KafkaBackend)
def test_factory_passes_config(self):
backend = get_pubsub(
pubsub_backend='kafka',
kafka_bootstrap_servers='myhost:9093',
kafka_security_protocol='SASL_SSL',
kafka_sasl_mechanism='PLAIN',
kafka_sasl_username='user',
kafka_sasl_password='pass',
)
assert isinstance(backend, KafkaBackend)
assert backend._bootstrap_servers == 'myhost:9093'
assert backend._admin_config['security.protocol'] == 'SASL_SSL'
assert backend._admin_config['sasl.mechanism'] == 'PLAIN'
assert backend._admin_config['sasl.username'] == 'user'
assert backend._admin_config['sasl.password'] == 'pass'
class TestAddPubsubArgsKafka:
def test_kafka_args_present(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([
'--pubsub-backend', 'kafka',
'--kafka-bootstrap-servers', 'myhost:9093',
])
assert args.pubsub_backend == 'kafka'
assert args.kafka_bootstrap_servers == 'myhost:9093'
def test_kafka_defaults_container(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([])
assert args.kafka_bootstrap_servers == 'kafka:9092'
assert args.kafka_security_protocol == 'PLAINTEXT'
def test_kafka_standalone_defaults_to_localhost(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser, standalone=True)
args = parser.parse_args([])
assert args.kafka_bootstrap_servers == 'localhost:9092'

View file

@ -1,18 +1,16 @@
"""
Unit tests for RabbitMQ backend queue name mapping and factory dispatch.
Unit tests for RabbitMQ backend topic parsing and factory dispatch.
Does not require a running RabbitMQ instance.
"""
import pytest
import argparse
pika = pytest.importorskip("pika", reason="pika not installed")
from trustgraph.base.rabbitmq_backend import RabbitMQBackend
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
class TestRabbitMQMapQueueName:
class TestRabbitMQParseTopic:
@pytest.fixture
def backend(self):
@ -20,43 +18,48 @@ class TestRabbitMQMapQueueName:
return b
def test_flow_is_durable(self, backend):
name, durable = backend.map_queue_name('flow:tg:text-completion-request')
exchange, cls, durable = backend._parse_topic('flow:tg:text-completion-request')
assert durable is True
assert name == 'tg.flow.text-completion-request'
assert cls == 'flow'
assert exchange == 'tg.flow.text-completion-request'
def test_notify_is_not_durable(self, backend):
name, durable = backend.map_queue_name('notify:tg:config')
exchange, cls, durable = backend._parse_topic('notify:tg:config')
assert durable is False
assert name == 'tg.notify.config'
assert cls == 'notify'
assert exchange == 'tg.notify.config'
def test_request_is_not_durable(self, backend):
name, durable = backend.map_queue_name('request:tg:config')
exchange, cls, durable = backend._parse_topic('request:tg:config')
assert durable is False
assert name == 'tg.request.config'
assert cls == 'request'
assert exchange == 'tg.request.config'
def test_response_is_not_durable(self, backend):
name, durable = backend.map_queue_name('response:tg:librarian')
exchange, cls, durable = backend._parse_topic('response:tg:librarian')
assert durable is False
assert name == 'tg.response.librarian'
assert cls == 'response'
assert exchange == 'tg.response.librarian'
def test_custom_topicspace(self, backend):
name, durable = backend.map_queue_name('flow:prod:my-queue')
assert name == 'prod.flow.my-queue'
exchange, cls, durable = backend._parse_topic('flow:prod:my-queue')
assert exchange == 'prod.flow.my-queue'
assert durable is True
def test_no_colon_defaults_to_flow(self, backend):
name, durable = backend.map_queue_name('simple-queue')
assert name == 'tg.simple-queue'
assert durable is False
exchange, cls, durable = backend._parse_topic('simple-queue')
assert exchange == 'tg.flow.simple-queue'
assert cls == 'flow'
assert durable is True
def test_invalid_class_raises(self, backend):
with pytest.raises(ValueError, match="Invalid queue class"):
backend.map_queue_name('unknown:tg:topic')
with pytest.raises(ValueError, match="Invalid topic class"):
backend._parse_topic('unknown:tg:topic')
def test_flow_with_flow_suffix(self, backend):
"""Queue names with flow suffix (e.g. :default) are preserved."""
name, durable = backend.map_queue_name('request:tg:prompt:default')
assert name == 'tg.request.prompt:default'
def test_topic_with_flow_suffix(self, backend):
"""Topic names with flow suffix (e.g. :default) are preserved."""
exchange, cls, durable = backend._parse_topic('request:tg:prompt:default')
assert exchange == 'tg.request.prompt:default'
class TestGetPubsubRabbitMQ:

View file

@ -304,14 +304,14 @@ class TestStreamingTypes:
assert chunk.content == "thinking..."
assert chunk.end_of_message is False
assert chunk.chunk_type == "thought"
assert chunk.message_type == "thought"
def test_agent_observation_creation(self):
"""Test creating AgentObservation chunk"""
chunk = AgentObservation(content="observing...", end_of_message=False)
assert chunk.content == "observing..."
assert chunk.chunk_type == "observation"
assert chunk.message_type == "observation"
def test_agent_answer_creation(self):
"""Test creating AgentAnswer chunk"""
@ -324,7 +324,7 @@ class TestStreamingTypes:
assert chunk.content == "answer"
assert chunk.end_of_message is True
assert chunk.end_of_dialog is True
assert chunk.chunk_type == "final-answer"
assert chunk.message_type == "final-answer"
def test_rag_chunk_creation(self):
"""Test creating RAGChunk"""

View file

@ -31,7 +31,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
def mock_query_request(self):
"""Create a mock query request for testing"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
@ -69,7 +68,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_single_vector(self, processor):
"""Test querying document embeddings with a single vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -83,7 +81,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
@ -101,7 +99,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_longer_vector(self, processor):
"""Test querying document embeddings with a longer vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3
@ -115,7 +112,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
@ -133,7 +130,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_with_limit(self, processor):
"""Test querying document embeddings respects limit parameter"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=2
@ -148,7 +144,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with the specified limit
processor.vecstore.search.assert_called_once_with(
@ -162,13 +158,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_vectors(self, processor):
"""Test querying document embeddings with empty vectors list"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[],
limit=5
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
@ -180,7 +175,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_search_results(self, processor):
"""Test querying document embeddings with empty search results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -189,7 +183,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called
processor.vecstore.search.assert_called_once_with(
@ -203,7 +197,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_unicode_documents(self, processor):
"""Test querying document embeddings with Unicode document content"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -217,7 +210,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify Unicode content is preserved in ChunkMatch objects
assert len(result) == 3
@ -230,7 +223,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_large_documents(self, processor):
"""Test querying document embeddings with large document content"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -244,7 +236,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify large content is preserved in ChunkMatch objects
assert len(result) == 2
@ -256,7 +248,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_special_characters(self, processor):
"""Test querying document embeddings with special characters in documents"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -270,7 +261,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify special characters are preserved in ChunkMatch objects
assert len(result) == 3
@ -283,13 +274,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_zero_limit(self, processor):
"""Test querying document embeddings with zero limit"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=0
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
@ -301,13 +291,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_negative_limit(self, processor):
"""Test querying document embeddings with negative limit"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=-1
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called (optimization for negative limit)
processor.vecstore.search.assert_not_called()
@ -319,7 +308,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_exception_handling(self, processor):
"""Test exception handling during query processing"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -330,13 +318,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_document_embeddings(query)
await processor.query_document_embeddings('test_user', query)
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
"""Test querying document embeddings with different vector dimensions"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
limit=5
@ -349,7 +336,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with the vector
processor.vecstore.search.assert_called_once()
@ -364,7 +351,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_multiple_results(self, processor):
"""Test querying document embeddings with multiple results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
@ -378,7 +364,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify results are ChunkMatch objects
assert len(result) == 3

View file

@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results1, mock_results2]
chunks = await processor.query_document_embeddings(mock_query_message)
chunks = await processor.query_document_embeddings('default', mock_query_message)
# Verify both queries were made
assert mock_index.query.call_count == 2
@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify limit is passed to query
mock_index.query.assert_called_once()
@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify different indexes used for different dimensions
assert processor.pinecone.Index.call_count == 2
@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_results.matches = []
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify empty results
assert chunks == []
@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify Unicode content is properly handled
assert len(chunks) == 2
@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify large content is properly handled
assert len(chunks) == 1
@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify all content types are properly handled
assert len(chunks) == 5
@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_document_embeddings(message)
await processor.query_document_embeddings('test_user', message)
@pytest.mark.asyncio
async def test_query_document_embeddings_index_access_failure(self, processor):
@ -427,7 +427,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
processor.pinecone.Index.side_effect = Exception("Index access failed")
with pytest.raises(Exception, match="Index access failed"):
await processor.query_document_embeddings(message)
await processor.query_document_embeddings('test_user', message)
@pytest.mark.asyncio
async def test_query_document_embeddings_vector_accumulation(self, processor):
@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify all queries were made
assert mock_index.query.call_count == 3

View file

@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'test_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('test_user', mock_message)
# Assert
# Verify query was called with correct parameters (with dimension suffix)
@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('multi_user', mock_message)
# Assert
# Verify query was called once
@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('limit_user', mock_message)
# Assert
# Verify query was called with exact limit (no multiplication)
@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('empty_user', mock_message)
# Assert
assert result == []
@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('dim_user', mock_message)
# Assert
# Verify query was called once with correct collection
@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'utf8_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('utf8_user', mock_message)
# Assert
assert len(result) == 2
@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_document_embeddings(mock_message)
await processor.query_document_embeddings('error_user', mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('zero_user', mock_message)
# Assert
# Should still query (with limit 0)
@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'large_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('large_user', mock_message)
# Assert
# Should query with full limit
@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
# This should raise a KeyError when trying to access payload['chunk_id']
with pytest.raises(KeyError):
await processor.query_document_embeddings(mock_message)
await processor.query_document_embeddings('payload_user', mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')

View file

@ -31,7 +31,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
def mock_query_request(self):
"""Create a mock query request for testing"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
@ -117,7 +116,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -131,7 +129,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
@ -154,7 +152,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_multiple_results(self, processor):
"""Test querying graph embeddings returns multiple results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
@ -168,7 +165,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
@ -186,7 +183,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_with_limit(self, processor):
"""Test querying graph embeddings respects limit parameter"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=2
@ -201,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with 2*limit for better deduplication
processor.vecstore.search.assert_called_once_with(
@ -215,7 +211,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_preserves_order(self, processor):
"""Test that query results preserve order from the vector store"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
@ -229,7 +224,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify results are in the same order as returned by the store
assert len(result) == 3
@ -241,7 +236,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_results_limited(self, processor):
"""Test that results are properly limited when store returns more than requested"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=2
@ -255,7 +249,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with the full vector
processor.vecstore.search.assert_called_once_with(
@ -269,13 +263,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_empty_vectors(self, processor):
"""Test querying graph embeddings with empty vectors list"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[],
limit=5
)
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
@ -287,7 +280,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_empty_search_results(self, processor):
"""Test querying graph embeddings with empty search results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -296,7 +288,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called
processor.vecstore.search.assert_called_once_with(
@ -310,7 +302,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
"""Test querying graph embeddings with mixed URI and literal results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -325,7 +316,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify all results are properly typed
assert len(result) == 4
@ -348,7 +339,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_exception_handling(self, processor):
"""Test exception handling during query processing"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -359,7 +349,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_graph_embeddings(query)
await processor.query_graph_embeddings('test_user', query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
@ -430,13 +420,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying graph embeddings with zero limit"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=0
)
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
@ -448,7 +437,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_longer_vector(self, processor):
"""Test querying graph embeddings with a longer vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
limit=5
@ -461,7 +449,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once()

View file

@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(mock_query_message)
entities = await processor.query_graph_embeddings('default', mock_query_message)
# Verify query was made once
assert mock_index.query.call_count == 1
@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify limit is respected
assert len(entities) == 2
@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify correct index used for 2D vector
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_results.matches = []
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify empty results
assert entities == []
@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Should get exactly 3 unique entities (respecting limit)
assert len(entities) == 3
@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Should only return 2 entities (respecting limit)
mock_index.query.assert_called_once()
@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_graph_embeddings(message)
await processor.query_graph_embeddings('test_user', message)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'test_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('test_user', mock_message)
# Assert
# Verify query was called with correct parameters (with dimension suffix)
@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('multi_user', mock_message)
# Assert
# Verify query was called once
@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('limit_user', mock_message)
# Assert
# Verify query was called with limit * 2
@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('empty_user', mock_message)
# Assert
assert result == []
@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('dim_user', mock_message)
# Assert
# Verify query was called once
@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'uri_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('uri_user', mock_message)
# Assert
assert len(result) == 3
@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_graph_embeddings(mock_message)
await processor.query_graph_embeddings('error_user', mock_message)
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('zero_user', mock_message)
# Assert
# Should still query (with limit 0)

View file

@ -9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestMemgraphQueryUserCollectionIsolation:
class TestMemgraphQueryWorkspaceCollectionIsolation:
"""Test cases for Memgraph query service with user/collection isolation"""
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -22,7 +22,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -32,13 +31,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SPO query for literal includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT 1000"
)
@ -48,14 +47,14 @@ class TestMemgraphQueryUserCollectionIsolation:
src="http://example.com/s",
rel="http://example.com/p",
value="test_object",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -63,7 +62,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -73,13 +71,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SP query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT 1000"
)
@ -88,14 +86,14 @@ class TestMemgraphQueryUserCollectionIsolation:
expected_literal_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
async def test_so_query_with_workspace_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -103,7 +101,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -113,13 +110,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SO query for nodes includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT 1000"
)
@ -128,14 +125,14 @@ class TestMemgraphQueryUserCollectionIsolation:
expected_query,
src="http://example.com/s",
uri="http://example.com/o",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -143,7 +140,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -153,13 +149,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify S query includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 1000"
)
@ -167,14 +163,14 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
async def test_po_query_with_workspace_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -182,7 +178,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -192,13 +187,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify PO query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT 1000"
)
@ -207,14 +202,14 @@ class TestMemgraphQueryUserCollectionIsolation:
expected_query,
uri="http://example.com/p",
value="literal",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -222,7 +217,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -232,13 +226,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify P query includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT 1000"
)
@ -246,14 +240,14 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -261,7 +255,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -271,13 +264,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify O query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 1000"
)
@ -285,14 +278,14 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
value="test_value",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -300,7 +293,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -310,36 +302,36 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
# Verify wildcard query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@ -363,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples('default', query)
# Verify defaults were used
calls = mock_driver.execute_query.call_args_list
@ -383,7 +375,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -410,7 +401,7 @@ class TestMemgraphQueryUserCollectionIsolation:
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
result = await processor.query_triples("test_user", query)
# Verify results are proper Triple objects
assert len(result) == 2

View file

@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestNeo4jQueryUserCollectionIsolation:
class TestNeo4jQueryWorkspaceCollectionIsolation:
"""Test cases for Neo4j query service with user/collection isolation"""
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -22,7 +22,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -32,13 +31,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SPO query for literal includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT 10"
)
@ -48,14 +47,14 @@ class TestNeo4jQueryUserCollectionIsolation:
src="http://example.com/s",
rel="http://example.com/p",
value="test_object",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -63,7 +62,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -73,13 +71,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SP query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT 10"
)
@ -88,16 +86,16 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_literal_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify SP query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT 10"
)
@ -106,14 +104,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_node_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
async def test_so_query_with_workspace_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -121,7 +119,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -131,13 +128,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SO query for nodes includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT 10"
)
@ -146,14 +143,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_query,
src="http://example.com/s",
uri="http://example.com/o",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -161,7 +158,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -171,13 +167,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify S query includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
@ -185,14 +181,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
async def test_po_query_with_workspace_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -200,7 +196,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -210,13 +205,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify PO query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT 10"
)
@ -225,14 +220,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_query,
uri="http://example.com/p",
value="literal",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -240,7 +235,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -250,13 +244,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify P query includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT 10"
)
@ -264,14 +258,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -279,7 +273,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -289,13 +282,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify O query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 10"
)
@ -303,14 +296,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
value="test_value",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -318,7 +311,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -328,36 +320,36 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify wildcard query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@ -381,7 +373,7 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples('default', query)
# Verify defaults were used
calls = mock_driver.execute_query.call_args_list
@ -401,7 +393,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -428,7 +419,7 @@ class TestNeo4jQueryUserCollectionIsolation:
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
result = await processor.query_triples("test_user", query)
# Verify results are proper Triple objects
assert len(result) == 2

View file

@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic:
"""Test parsing of schema configuration"""
processor = MagicMock()
processor.schemas = {}
processor.schema_builders = {}
processor.graphql_schemas = {}
processor.config_key = "schema"
processor.schema_builder = MagicMock()
processor.schema_builder.clear = MagicMock()
processor.schema_builder.add_schema = MagicMock()
processor.schema_builder.build = MagicMock(return_value=MagicMock())
processor.query_cassandra = MagicMock()
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic:
}
# Process config
await processor.on_schema_config(schema_config, version=1)
await processor.on_schema_config("default", schema_config, version=1)
# Verify schema was loaded
assert "customer" in processor.schemas
schema = processor.schemas["customer"]
assert "customer" in processor.schemas["default"]
schema = processor.schemas["default"]["customer"]
assert schema.name == "customer"
assert len(schema.fields) == 3
@ -147,39 +146,40 @@ class TestRowsGraphQLQueryLogic:
status_field = next(f for f in schema.fields if f.name == "status")
assert status_field.enum_values == ["active", "inactive"]
# Verify schema builder was called
processor.schema_builder.add_schema.assert_called_once()
processor.schema_builder.build.assert_called_once()
# Verify per-workspace schema builder was created and graphql schema built
assert "default" in processor.schema_builders
assert "default" in processor.graphql_schemas
@pytest.mark.asyncio
async def test_graphql_context_handling(self):
"""Test GraphQL execution context setup"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
graphql_schema = AsyncMock()
processor.graphql_schemas = {"default": graphql_schema}
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Mock schema execution
mock_result = MagicMock()
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
mock_result.errors = None
processor.graphql_schema.execute.return_value = mock_result
graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
workspace="default",
query='{ customers { id name } }',
variables={},
operation_name=None,
user="test_user",
collection="test_collection"
)
# Verify schema.execute was called with correct context
processor.graphql_schema.execute.assert_called_once()
call_args = processor.graphql_schema.execute.call_args
graphql_schema.execute.assert_called_once()
call_args = graphql_schema.execute.call_args
# Verify context was passed
context = call_args[1]['context_value']
assert context["processor"] == processor
assert context["user"] == "test_user"
assert context["workspace"] == "default"
assert context["collection"] == "test_collection"
# Verify result structure
@ -190,7 +190,8 @@ class TestRowsGraphQLQueryLogic:
async def test_error_handling_graphql_errors(self):
"""Test GraphQL error handling and conversion"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
graphql_schema = AsyncMock()
processor.graphql_schemas = {"default": graphql_schema}
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Create a simple object to simulate GraphQL error
@ -212,13 +213,13 @@ class TestRowsGraphQLQueryLogic:
mock_result = MagicMock()
mock_result.data = None
mock_result.errors = [mock_error]
processor.graphql_schema.execute.return_value = mock_result
graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
workspace="default",
query='{ customers { invalid_field } }',
variables={},
operation_name=None,
user="test_user",
collection="test_collection"
)
@ -248,7 +249,6 @@ class TestRowsGraphQLQueryLogic:
# Create mock message
mock_msg = MagicMock()
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ customers { id name } }',
variables={},
@ -259,6 +259,7 @@ class TestRowsGraphQLQueryLogic:
# Mock flow
mock_flow = MagicMock()
mock_flow.workspace = "default"
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
@ -267,10 +268,10 @@ class TestRowsGraphQLQueryLogic:
# Verify query was executed
processor.execute_graphql_query.assert_called_once_with(
workspace="default",
query='{ customers { id name } }',
variables={},
operation_name=None,
user="test_user",
collection="test_collection"
)
@ -297,7 +298,6 @@ class TestRowsGraphQLQueryLogic:
# Create mock message
mock_msg = MagicMock()
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ invalid_query }',
variables={},
@ -330,7 +330,8 @@ class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
async def test_query_with_index_match(self):
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_with_index_match(self, mock_async_execute):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
@ -340,10 +341,10 @@ class TestUnifiedTableQueries:
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
# Mock async_execute to return test data
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
processor.session.execute.return_value = [mock_row]
mock_async_execute.return_value = [mock_row]
schema = RowSchema(
name="products",
@ -356,7 +357,7 @@ class TestUnifiedTableQueries:
# Query with filter on indexed field
results = await processor.query_cassandra(
user="test_user",
workspace="test_workspace",
collection="test_collection",
schema_name="products",
row_schema=schema,
@ -366,14 +367,14 @@ class TestUnifiedTableQueries:
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
mock_async_execute.assert_called_once()
# Verify query structure - should query unified rows table
call_args = processor.session.execute.call_args
query = call_args[0][0]
params = call_args[0][1]
call_args = mock_async_execute.call_args
query = call_args[0][1]
params = call_args[0][2]
assert "SELECT data, source FROM test_user.rows" in query
assert "SELECT data, source FROM test_workspace.rows" in query
assert "collection = %s" in query
assert "schema_name = %s" in query
assert "index_name = %s" in query
@ -390,7 +391,8 @@ class TestUnifiedTableQueries:
assert results[0]["category"] == "electronics"
@pytest.mark.asyncio
async def test_query_without_index_match(self):
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_without_index_match(self, mock_async_execute):
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
@ -401,12 +403,12 @@ class TestUnifiedTableQueries:
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
# Mock async_execute to return test data
mock_row1 = MagicMock()
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
mock_row2 = MagicMock()
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
processor.session.execute.return_value = [mock_row1, mock_row2]
mock_async_execute.return_value = [mock_row1, mock_row2]
schema = RowSchema(
name="products",
@ -419,7 +421,7 @@ class TestUnifiedTableQueries:
# Query with filter on non-indexed field
results = await processor.query_cassandra(
user="test_user",
workspace="test_workspace",
collection="test_collection",
schema_name="products",
row_schema=schema,
@ -428,8 +430,8 @@ class TestUnifiedTableQueries:
)
# Query should use ALLOW FILTERING for scan
call_args = processor.session.execute.call_args
query = call_args[0][0]
call_args = mock_async_execute.call_args
query = call_args[0][1]
assert "ALLOW FILTERING" in query

View file

@ -95,7 +95,6 @@ class TestCassandraQueryProcessor:
# Create query request with all SPO values
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -103,7 +102,7 @@ class TestCassandraQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify KnowledgeGraph was created with correct parameters
mock_kg_class.assert_called_once_with(
@ -170,7 +169,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -178,7 +176,7 @@ class TestCassandraQueryProcessor:
limit=50
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
assert len(result) == 1
@ -207,7 +205,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=None,
@ -215,7 +212,7 @@ class TestCassandraQueryProcessor:
limit=25
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
assert len(result) == 1
@ -244,7 +241,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Term(type=LITERAL, value='test_predicate'),
@ -252,7 +248,7 @@ class TestCassandraQueryProcessor:
limit=10
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
assert len(result) == 1
@ -281,7 +277,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
@ -289,7 +284,7 @@ class TestCassandraQueryProcessor:
limit=75
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
assert len(result) == 1
@ -319,7 +314,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
@ -327,7 +321,7 @@ class TestCassandraQueryProcessor:
limit=1000
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1
@ -425,7 +419,6 @@ class TestCassandraQueryProcessor:
)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -433,7 +426,7 @@ class TestCassandraQueryProcessor:
limit=100
)
await processor.query_triples(query)
await processor.query_triples('test_user', query)
# Verify KnowledgeGraph was created with authentication
mock_kg_class.assert_called_once_with(
@ -463,7 +456,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -472,11 +464,11 @@ class TestCassandraQueryProcessor:
)
# First query should create TrustGraph
await processor.query_triples(query)
await processor.query_triples('test_user', query)
assert mock_kg_class.call_count == 1
# Second query with same table should reuse TrustGraph
await processor.query_triples(query)
await processor.query_triples('test_user', query)
assert mock_kg_class.call_count == 1 # Should not increase
@pytest.mark.asyncio
@ -504,7 +496,6 @@ class TestCassandraQueryProcessor:
# First query
query1 = TriplesQueryRequest(
user='user1',
collection='collection1',
s=Term(type=LITERAL, value='test_subject'),
p=None,
@ -512,12 +503,11 @@ class TestCassandraQueryProcessor:
limit=100
)
await processor.query_triples(query1)
await processor.query_triples('user1', query1)
assert processor.table == 'user1'
# Second query with different table
query2 = TriplesQueryRequest(
user='user2',
collection='collection2',
s=Term(type=LITERAL, value='test_subject'),
p=None,
@ -525,7 +515,7 @@ class TestCassandraQueryProcessor:
limit=100
)
await processor.query_triples(query2)
await processor.query_triples('user2', query2)
assert processor.table == 'user2'
# Verify TrustGraph was created twice
@ -544,7 +534,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -553,7 +542,7 @@ class TestCassandraQueryProcessor:
)
with pytest.raises(Exception, match="Query failed"):
await processor.query_triples(query)
await processor.query_triples('test_user', query)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
@ -582,7 +571,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -590,7 +578,7 @@ class TestCassandraQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
assert len(result) == 2
assert result[0].o.value == 'object1'
@ -621,7 +609,6 @@ class TestCassandraQueryPerformanceOptimizations:
# PO query pattern (predicate + object, find subjects)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Term(type=LITERAL, value='test_predicate'),
@ -629,7 +616,7 @@ class TestCassandraQueryPerformanceOptimizations:
limit=50
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify get_po was called (should use optimized po_table)
mock_tg_instance.get_po.assert_called_once_with(
@ -662,7 +649,6 @@ class TestCassandraQueryPerformanceOptimizations:
# OS query pattern (object + subject, find predicates)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=None,
@ -670,7 +656,7 @@ class TestCassandraQueryPerformanceOptimizations:
limit=25
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.get_os.assert_called_once_with(
@ -721,7 +707,6 @@ class TestCassandraQueryPerformanceOptimizations:
mock_tg_instance.reset_mock()
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value=s) if s else None,
p=Term(type=LITERAL, value=p) if p else None,
@ -729,7 +714,7 @@ class TestCassandraQueryPerformanceOptimizations:
limit=10
)
await processor.query_triples(query)
await processor.query_triples('test_user', query)
# Verify the correct method was called
method = getattr(mock_tg_instance, expected_method)
@ -780,7 +765,6 @@ class TestCassandraQueryPerformanceOptimizations:
# This is the query pattern that was slow with ALLOW FILTERING
query = TriplesQueryRequest(
user='large_dataset_user',
collection='massive_collection',
s=None,
p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
@ -788,7 +772,7 @@ class TestCassandraQueryPerformanceOptimizations:
limit=1000
)
result = await processor.query_triples(query)
result = await processor.query_triples('large_dataset_user', query)
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
mock_tg_instance.get_po.assert_called_once_with(

Some files were not shown because too many files have changed in this diff Show more