IAM tech spec: Auth and access management current state and proposed

changes.

Support for separate workspaces

Addition of workspace CLI support for test purposes
This commit is contained in:
Cyber MacGeddon 2026-04-18 23:07:26 +01:00
parent 48da6c5f8b
commit db05427d0e
219 changed files with 4875 additions and 2616 deletions

View file

@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration:
async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
"""Test basic agent integration with structured query tool"""
# Arrange - Load tool configuration
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
# Create agent request
request = AgentRequest(
@ -119,6 +119,7 @@ Args: {
# Mock flow parameter in agent_processor.on_request
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -146,7 +147,7 @@ Args: {
async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
"""Test agent handling of structured query errors"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="Find data from a table that doesn't exist using structured query.",
@ -199,6 +200,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -221,7 +223,7 @@ Args: {
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
"""Test agent using structured query in multi-step reasoning"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="First find all customers from California, then tell me how many orders they have made.",
@ -279,6 +281,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -313,7 +316,7 @@ Args: {
}
}
await agent_processor.on_tools_config(tool_config_with_collection, "v1")
await agent_processor.on_tools_config("default", tool_config_with_collection, "v1")
request = AgentRequest(
question="Query the sales data for recent transactions.",
@ -371,6 +374,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -394,10 +398,10 @@ Args: {
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
"""Test that structured query tool arguments are properly validated"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
# Check that the tool was registered with correct arguments
tools = agent_processor.agent.tools
tools = agent_processor.agents["default"].tools
assert "structured-query" in tools
structured_tool = tools["structured-query"]
@ -414,7 +418,7 @@ Args: {
async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
"""Test that structured query results are properly formatted for agent consumption"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="Get customer information and format it nicely.",
@ -482,6 +486,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)

View file

@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration:
)
# Set up schemas
proc.schemas = sample_schemas
proc.schemas = {"default": dict(sample_schemas)}
# Mock the client method
proc.client = MagicMock()
@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration:
}
# Act - Update configuration
await integration_processor.on_schema_config(new_schema_config, "v2")
await integration_processor.on_schema_config("default", new_schema_config, "v2")
# Arrange - Test query using new schema
request = QuestionToStructuredQueryRequest(
@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration:
await integration_processor.on_message(msg, consumer, flow)
# Assert
assert "inventory" in integration_processor.schemas
assert "inventory" in integration_processor.schemas["default"]
response_call = flow_response.send.call_args
response = response_call[0][0]
assert response.detected_schemas == ["inventory"]
@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration:
graphql_generation_template="custom-graphql-generator"
)
custom_processor.schemas = sample_schemas
custom_processor.schemas = {"default": dict(sample_schemas)}
custom_processor.client = MagicMock()
request = QuestionToStructuredQueryRequest(
@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration:
] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
)
integration_processor.schemas.update(large_schema_set)
integration_processor.schemas["default"].update(large_schema_set)
request = QuestionToStructuredQueryRequest(
question="Show me data from table_05 and table_12",
@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration:
msg.properties.return_value = {"id": f"concurrent-test-{i}"}
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response

View file

@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
return context
@pytest.mark.asyncio
@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration:
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Act
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Assert
assert len(processor.schemas) == 2
assert "customer_records" in processor.schemas
assert "product_catalog" in processor.schemas
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 2
assert "customer_records" in ws_schemas
assert "product_catalog" in ws_schemas
# Verify customer schema
customer_schema = processor.schemas["customer_records"]
customer_schema = ws_schemas["customer_records"]
assert customer_schema.name == "customer_records"
assert len(customer_schema.fields) == 4
# Verify product schema
product_schema = processor.schemas["product_catalog"]
product_schema = ws_schemas["product_catalog"]
assert product_schema.name == "product_catalog"
assert len(product_schema.fields) == 4
@ -237,7 +239,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create realistic customer data chunk
metadata = Metadata(
@ -304,7 +306,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create realistic product data chunk
metadata = Metadata(
@ -368,7 +370,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create multiple test chunks
chunks_data = [
@ -431,19 +433,21 @@ class TestObjectExtractionServiceIntegration:
"customer_records": integration_config["schema"]["customer_records"]
}
}
await processor.on_schema_config(initial_config, version=1)
assert len(processor.schemas) == 1
assert "customer_records" in processor.schemas
assert "product_catalog" not in processor.schemas
await processor.on_schema_config("default", initial_config, version=1)
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 1
assert "customer_records" in ws_schemas
assert "product_catalog" not in ws_schemas
# Act - Reload with full configuration
await processor.on_schema_config(integration_config, version=2)
await processor.on_schema_config("default", integration_config, version=2)
# Assert
assert len(processor.schemas) == 2
assert "customer_records" in processor.schemas
assert "product_catalog" in processor.schemas
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 2
assert "customer_records" in ws_schemas
assert "product_catalog" in ws_schemas
@pytest.mark.asyncio
async def test_error_resilience_integration(self, integration_config):
@ -474,10 +478,11 @@ class TestObjectExtractionServiceIntegration:
return AsyncMock()
failing_flow.side_effect = failing_context_router
failing_flow.workspace = "default"
processor.flow = failing_flow
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create test chunk
metadata = Metadata(id="error-test", user="test", collection="test")
@ -510,7 +515,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create chunk with rich metadata
original_metadata = Metadata(

View file

@ -87,6 +87,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
return context
@pytest.fixture
@ -109,7 +110,7 @@ class TestPromptStreaming:
def prompt_processor_streaming(self, mock_prompt_manager):
"""Create Prompt processor with streaming support"""
processor = MagicMock()
processor.manager = mock_prompt_manager
processor.managers = {"default": mock_prompt_manager}
processor.config_key = "prompt"
# Bind the actual on_request method
@ -248,6 +249,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
request = PromptRequest(
id="test_prompt",
@ -341,6 +343,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
request = PromptRequest(
id="test_prompt",

View file

@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class _MockFlowDefault:
"""Mock Flow with default workspace for testing."""
workspace = "default"
name = "default"
id = "test-processor"
mock_flow_default = _MockFlowDefault()
@pytest.mark.integration
class TestRowsCassandraIntegration:
"""Integration tests for Cassandra row storage with unified table"""
@ -125,8 +136,8 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
assert "customer_records" in processor.schemas
await processor.on_schema_config("default", config, version=1)
assert "customer_records" in processor.schemas["default"]
# Step 2: Process an ExtractedObject
test_obj = ExtractedObject(
@ -149,7 +160,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify Cassandra interactions
assert mock_cluster.connect.called
@ -209,8 +220,8 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
assert len(processor.schemas) == 2
await processor.on_schema_config("default", config, version=1)
assert len(processor.schemas["default"]) == 2
# Process objects for different schemas
product_obj = ExtractedObject(
@ -233,7 +244,7 @@ class TestRowsCassandraIntegration:
for obj in [product_obj, order_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# All data goes into the same unified rows table
table_calls = [call for call in mock_session.execute.call_args_list
@ -256,15 +267,17 @@ class TestRowsCassandraIntegration:
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Schema with multiple indexed fields
processor.schemas["indexed_data"] = RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
processor.schemas["default"] = {
"indexed_data": RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test"),
@ -282,7 +295,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 data inserts (one per indexed field: id, category, status)
rows_insert_calls = [call for call in mock_session.execute.call_args_list
@ -342,7 +355,7 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
await processor.on_schema_config("default", config, version=1)
# Process batch object with multiple values
batch_obj = ExtractedObject(
@ -376,7 +389,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify unified table creation
table_calls = [call for call in mock_session.execute.call_args_list
@ -396,10 +409,12 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["empty_test"] = RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
processor.schemas["default"] = {
"empty_test": RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
}
# Process empty batch object
empty_obj = ExtractedObject(
@ -413,7 +428,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should not create any data insert statements for empty batch
# (partition registration may still happen)
@ -428,14 +443,16 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["map_test"] = RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
processor.schemas["default"] = {
"map_test": RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test"),
@ -448,7 +465,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify insert uses map for data
rows_insert_calls = [call for call in mock_session.execute.call_args_list
@ -473,13 +490,15 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["partition_test"] = RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
processor.schemas["default"] = {
"partition_test": RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="my_collection"),
@ -492,7 +511,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify partition registration
partition_inserts = [call for call in mock_session.execute.call_args_list

View file

@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
"""Test schema configuration loading and GraphQL schema generation"""
# Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Verify schemas were loaded
assert len(processor.schemas) == 2
@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
"""Test Cassandra connection and dynamic table creation"""
# Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Connect to Cassandra
processor.connect_cassandra()
@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
"""Test inserting data and querying via GraphQL"""
# Load schema and connect
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Setup test data
@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_query_with_filters(self, processor, sample_schema_config):
"""Test GraphQL queries with filtering on indexed fields"""
# Setup (reuse previous setup)
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
keyspace = "test_user"
@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_error_handling(self, processor, sample_schema_config):
"""Test GraphQL error handling for invalid queries"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Test invalid field query
invalid_query = '''
@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_message_processing_integration(self, processor, sample_schema_config):
"""Test full message processing workflow"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Create mock message
@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_concurrent_queries(self, processor, sample_schema_config):
"""Test handling multiple concurrent GraphQL queries"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Create multiple query tasks
@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration:
}
}
await processor.on_schema_config(initial_config, version=1)
await processor.on_schema_config("default", initial_config, version=1)
assert len(processor.schemas) == 1
assert "simple" in processor.schemas
@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration:
}
}
await processor.on_schema_config(updated_config, version=2)
await processor.on_schema_config("default", updated_config, version=2)
# Verify updated schemas
assert len(processor.schemas) == 2
@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_large_result_set_handling(self, processor, sample_schema_config):
"""Test handling of large query result sets"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
keyspace = "large_test_user"
@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance:
}
}
await processor.on_schema_config(schema_config, version=1)
await processor.on_schema_config("default", schema_config, version=1)
# Measure query execution time
start_time = time.time()

View file

@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to call think and observe callbacks
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -58,6 +61,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()
@ -129,6 +133,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to return Final directly
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -148,6 +155,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()

View file

@ -241,7 +241,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return "tool result"
svc.invoke_tool = mock_invoke
@ -260,6 +260,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
@ -280,7 +281,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return {"data": [1, 2, 3]}
svc.invoke_tool = mock_invoke
@ -298,6 +299,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -317,7 +319,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def failing_invoke(name, params):
async def failing_invoke(workspace, name, params):
raise RuntimeError("tool broke")
svc.invoke_tool = failing_invoke
@ -330,6 +332,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -350,7 +353,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def rate_limited(name, params):
async def rate_limited(workspace, name, params):
raise TooManyRequests("slow down")
svc.invoke_tool = rate_limited
@ -362,6 +365,7 @@ class TestToolServiceOnRequest:
flow = MagicMock()
flow.producer = {"response": AsyncMock()}
flow.name = "test-flow"
flow.workspace = "default"
with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), flow)
@ -376,7 +380,8 @@ class TestToolServiceOnRequest:
received = {}
async def capture_invoke(name, params):
async def capture_invoke(workspace, name, params):
received["workspace"] = workspace
received["name"] = name
received["params"] = params
return "ok"
@ -390,6 +395,7 @@ class TestToolServiceOnRequest:
flow = lambda name: mock_pub
flow.producer = {"response": mock_pub}
flow.name = "f"
flow.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(

View file

@ -1,17 +1,14 @@
"""
Tests for AsyncProcessor config notify pattern:
- register_config_handler with types filtering
- on_config_notify version comparison and type matching
- fetch_config with short-lived client
- fetch_and_apply_config retry logic
- on_config_notify version comparison, type/workspace matching
- fetch_and_apply_config retry logic over per-workspace fetches
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, Mock
from trustgraph.schema import Term, IRI, LITERAL
# Patch heavy dependencies before importing AsyncProcessor
@pytest.fixture
def processor():
"""Create an AsyncProcessor with mocked dependencies."""
@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
assert len(processor.config_handlers) == 2
def _notify_msg(version, changes):
"""Build a Mock config-notify message with given version and changes dict."""
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestOnConfigNotify:
@pytest.mark.asyncio
@ -77,9 +81,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=3, types=["prompt"])
msg = _notify_msg(3, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -91,9 +93,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=5, types=["prompt"])
msg = _notify_msg(5, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -105,9 +105,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=2, types=["schema"])
msg = _notify_msg(2, {"schema": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -121,40 +119,36 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
# Mock fetch_config
mock_config = {"prompt": {"key": "value"}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={"key": "value"},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_called_once_with(
"default", {"prompt": {"key": "value"}}, 2
)
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_handler_without_types_always_called(self, processor):
async def test_handler_without_types_ignored_on_notify(self, processor):
"""Handlers registered without types never fire on notifications."""
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler) # No types = all
processor.register_config_handler(handler) # No types
mock_config = {"anything": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["whatever"])
msg = _notify_msg(2, {"whatever": ["default"]})
await processor.on_config_notify(msg, None, None)
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_not_called()
# Version still advances past the notify
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_mixed_handlers_type_filtering(self, processor):
@ -168,156 +162,149 @@ class TestOnConfigNotify:
processor.register_config_handler(schema_handler, types=["schema"])
processor.register_config_handler(all_handler)
mock_config = {"prompt": {}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
prompt_handler.assert_called_once()
prompt_handler.assert_called_once_with(
"default", {"prompt": {}}, 2
)
schema_handler.assert_not_called()
all_handler.assert_called_once()
all_handler.assert_not_called()
@pytest.mark.asyncio
async def test_empty_types_invokes_all(self, processor):
"""Empty types list (startup signal) should invoke all handlers."""
async def test_multi_workspace_notify_invokes_handler_per_ws(
self, processor
):
"""Notify affecting multiple workspaces invokes handler once per workspace."""
processor.config_version = 1
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
mock_config = {}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=[])
msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
await processor.on_config_notify(msg, None, None)
h1.assert_called_once()
h2.assert_called_once()
assert handler.call_count == 2
called_workspaces = {c.args[0] for c in handler.call_args_list}
assert called_workspaces == {"ws1", "ws2"}
@pytest.mark.asyncio
async def test_fetch_failure_handled(self, processor):
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler)
processor.register_config_handler(handler, types=["prompt"])
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
side_effect=RuntimeError("Connection failed")
side_effect=RuntimeError("Connection failed"),
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
# Should not raise
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
class TestFetchConfig:
@pytest.mark.asyncio
async def test_fetch_returns_config_and_version(self, processor):
mock_resp = Mock()
mock_resp.error = None
mock_resp.config = {"prompt": {"key": "val"}}
mock_resp.version = 42
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
config, version = await processor.fetch_config()
assert config == {"prompt": {"key": "val"}}
assert version == 42
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_raises_on_error_response(self, processor):
mock_resp = Mock()
mock_resp.error = Mock(message="not found")
mock_resp.config = {}
mock_resp.version = 0
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(RuntimeError, match="Config error"):
await processor.fetch_config()
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_stops_client_on_exception(self, processor):
mock_client = AsyncMock()
mock_client.request.side_effect = TimeoutError("timeout")
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(TimeoutError):
await processor.fetch_config()
mock_client.stop.assert_called_once()
class TestFetchAndApplyConfig:
@pytest.mark.asyncio
async def test_applies_config_to_all_handlers(self, processor):
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
async def test_applies_config_per_workspace(self, processor):
"""Startup fetch invokes handler once per workspace affected."""
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {
"ws1": {"k": "v1"},
"ws2": {"k": "v2"},
}, 10
mock_config = {"prompt": {}, "schema": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 10)
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
# On startup, all handlers are invoked regardless of type
h1.assert_called_once_with(mock_config, 10)
h2.assert_called_once_with(mock_config, 10)
assert h.call_count == 2
call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
assert call_map["ws1"] == {"prompt": {"k": "v1"}}
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
assert processor.config_version == 10
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
call_count = 0
mock_config = {"prompt": {}}
async def test_handler_without_types_skipped_at_startup(self, processor):
"""Handlers registered without types fetch nothing at startup."""
typed = AsyncMock()
untyped = AsyncMock()
processor.register_config_handler(typed, types=["prompt"])
processor.register_config_handler(untyped)
async def mock_fetch():
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {"default": {}}, 1
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
typed.assert_called_once()
untyped.assert_not_called()
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
call_count = 0
async def fake_fetch_all(client, config_type):
nonlocal call_count
call_count += 1
if call_count < 3:
raise RuntimeError("not ready")
return mock_config, 5
return {"default": {"k": "v"}}, 5
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
patch('asyncio.sleep', new_callable=AsyncMock):
mock_client = AsyncMock()
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
), patch('asyncio.sleep', new_callable=AsyncMock):
await processor.fetch_and_apply_config()
assert call_count == 3
assert processor.config_version == 5
h.assert_called_once_with(
"default", {"prompt": {"k": "v"}}, 5
)

View file

@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs():
spec_two = MagicMock()
processor = MagicMock(specifications=[spec_one, spec_two])
flow = Flow("processor-1", "flow-a", processor, {"answer": 42})
flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
assert flow.id == "processor-1"
assert flow.name == "flow-a"
assert flow.workspace == "default"
assert flow.producer == {}
assert flow.consumer == {}
assert flow.parameter == {}
@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs():
def test_flow_start_and_stop_visit_all_consumers():
consumer_one = AsyncMock()
consumer_two = AsyncMock()
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.consumer = {"one": consumer_one, "two": consumer_two}
asyncio.run(flow.start())
@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers():
def test_flow_call_returns_values_in_priority_order():
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.producer["shared"] = "producer-value"
flow.consumer["consumer-only"] = "consumer-value"
flow.consumer["shared"] = "consumer-value"

View file

@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
# Assert - Flow should be created with access to processor specifications
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
# The flow should have access to the processor's specifications
# (The exact mechanism depends on Flow implementation)

View file

@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
assert flow_name in processor.flows
assert ("default", flow_name) in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', flow_name, processor, flow_defn
'test-processor', flow_name, "default", processor, flow_defn
)
mock_flow.start.assert_called_once()
@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
await processor.start_flow(flow_name, {'config': 'test-config'})
await processor.start_flow("default", flow_name, {'config': 'test-config'})
await processor.stop_flow(flow_name)
await processor.stop_flow("default", flow_name)
assert flow_name not in processor.flows
assert ("default", flow_name) not in processor.flows
mock_flow.stop.assert_called_once()
@with_async_processor_patches
@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
processor = FlowProcessor(**config)
await processor.stop_flow('non-existent-flow')
await processor.stop_flow("default", 'non-existent-flow')
assert processor.flows == {}
@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data, version=1)
await processor.on_configure_flows("default", config_data, version=1)
assert 'test-flow' in processor.flows
assert ("default", 'test-flow') in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', 'test-flow', processor,
'test-processor', 'test-flow', "default", processor,
{'config': 'test-config'}
)
mock_flow.start.assert_called_once()
@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data, version=1)
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
'other-data': 'some-value'
}
await processor.on_configure_flows(config_data, version=1)
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data1, version=1)
await processor.on_configure_flows("default", config_data1, version=1)
config_data2 = {
'processor:test-processor': {
@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data2, version=2)
await processor.on_configure_flows("default", config_data2, version=2)
assert 'flow1' not in processor.flows
assert ("default", 'flow1') not in processor.flows
mock_flow1.stop.assert_called_once()
assert 'flow2' in processor.flows
assert ("default", 'flow2') in processor.flows
mock_flow2.start.assert_called_once()
@with_async_processor_patches

View file

@ -109,7 +109,8 @@ class TestListConfigItems:
url='http://custom.com',
config_type='prompt',
format_type='json',
token=None
token=None,
workspace='default'
)
def test_list_main_uses_defaults(self):
@ -128,7 +129,8 @@ class TestListConfigItems:
url='http://localhost:8088/',
config_type='prompt',
format_type='text',
token=None
token=None,
workspace='default'
)
@ -196,7 +198,8 @@ class TestGetConfigItem:
config_type='prompt',
key='template-1',
format_type='json',
token=None
token=None,
workspace='default'
)
@ -253,7 +256,8 @@ class TestPutConfigItem:
config_type='prompt',
key='new-template',
value='Custom prompt: {input}',
token=None
token=None,
workspace='default'
)
def test_put_main_with_stdin_arg(self):
@ -278,7 +282,8 @@ class TestPutConfigItem:
config_type='prompt',
key='stdin-template',
value=stdin_content,
token=None
token=None,
workspace='default'
)
def test_put_main_mutually_exclusive_args(self):
@ -334,7 +339,8 @@ class TestDeleteConfigItem:
url='http://custom.com',
config_type='prompt',
key='old-template',
token=None
token=None,
workspace='default'
)

View file

@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
# Verify Api was created with correct parameters
mock_api_class.assert_called_once_with(
url="http://test.example.com/",
token="test-token"
token="test-token",
workspace="default"
)
# Verify bulk client was obtained

View file

@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_set_main_structured_query_no_arguments_needed(self):
@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_valid_types_includes_row_embeddings_query(self):
@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
show_main()
mock_show.assert_called_once_with(url='http://custom.com', token=None)
mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
class TestShowToolsRowEmbeddingsQuery:

View file

@ -28,10 +28,12 @@ def mock_flow_config():
"""Mock flow configuration."""
mock_config = Mock()
mock_config.flows = {
"test-flow": {
"interfaces": {
"triples-store": {"flow": "test-triples-queue"},
"graph-embeddings-store": {"flow": "test-ge-queue"}
"test-user": {
"test-flow": {
"interfaces": {
"triples-store": {"flow": "test-triples-queue"},
"graph-embeddings-store": {"flow": "test-ge-queue"}
}
}
}
}

View file

@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert 'customers' in processor.schemas
assert processor.schemas['customers'].name == 'customers'
assert len(processor.schemas['customers'].fields) == 3
assert 'customers' in processor.schemas["default"]
assert processor.schemas["default"]['customers'].name == 'customers'
assert len(processor.schemas["default"]['customers'].fields) == 3
async def test_on_schema_config_handles_missing_type(self):
"""Test that missing schema type is handled gracefully"""
@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
'other_type': {}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert processor.schemas == {}
assert processor.schemas.get("default", {}) == {}
async def test_on_message_drops_unknown_collection(self):
"""Test that messages for unknown collections are dropped"""
@ -325,14 +325,16 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
processor.known_collections[('test_user', 'test_collection')] = {}
# Set up schema
processor.schemas['customers'] = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
processor.schemas["default"] = {
'customers': RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
}
metadata = MagicMock()
metadata.user = 'test_user'
@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
return MagicMock()
mock_flow = MagicMock(side_effect=flow_factory)
mock_flow.workspace = "default"
await processor.on_message(mock_msg, MagicMock(), mock_flow)

View file

@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
ConfigReceiver.config_loader = Mock()
def _notify(version, changes):
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestConfigReceiver:
"""Test cases for ConfigReceiver class"""
@ -47,98 +53,70 @@ class TestConfigReceiver:
assert handler2 in config_receiver.flow_handlers
@pytest.mark.asyncio
async def test_on_config_notify_new_version(self):
"""Test on_config_notify triggers fetch for newer version"""
async def test_on_config_notify_new_version_fetches_per_workspace(self):
"""Notify with newer version fetches each affected workspace."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
# Mock fetch_and_apply
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with newer version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 1
msg = _notify(2, {"flow": ["ws1", "ws2"]})
await config_receiver.on_config_notify(msg, None, None)
assert set(fetch_calls) == {"ws1", "ws2"}
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_old_version_ignored(self):
"""Test on_config_notify ignores older versions"""
"""Older-version notifies are ignored."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 5
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with older version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=3, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 0
msg = _notify(3, {"flow": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
assert fetch_calls == []
@pytest.mark.asyncio
async def test_on_config_notify_irrelevant_types_ignored(self):
"""Test on_config_notify ignores types the gateway doesn't care about"""
"""Notifies without flow changes advance version but skip fetch."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with non-flow type
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
# Version should be updated but no fetch
assert len(fetch_calls) == 0
msg = _notify(2, {"prompt": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
assert fetch_calls == []
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_flow_type_triggers_fetch(self):
"""Test on_config_notify fetches for flow-related types"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
for type_name in ["flow"]:
fetch_calls.clear()
config_receiver.config_version = 1
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=[type_name])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
@pytest.mark.asyncio
async def test_on_config_notify_exception_handling(self):
"""Test on_config_notify handles exceptions gracefully"""
"""on_config_notify swallows exceptions from message decode."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Create notify message that causes an exception
mock_msg = Mock()
mock_msg.value.side_effect = Exception("Test exception")
@ -146,19 +124,18 @@ class TestConfigReceiver:
await config_receiver.on_config_notify(mock_msg, None, None)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_new_flows(self):
"""Test fetch_and_apply starts new flows"""
async def test_fetch_and_apply_workspace_starts_new_flows(self):
"""fetch_and_apply_workspace starts newly-configured flows."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Mock _create_config_client to return a mock client
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}',
"flow2": '{"name": "test_flow_2"}'
"flow2": '{"name": "test_flow_2"}',
}
}
@ -167,36 +144,39 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
start_flow_calls = []
async def mock_start_flow(id, flow):
start_flow_calls.append((id, flow))
async def mock_start_flow(workspace, id, flow):
start_flow_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.config_version == 5
assert "flow1" in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" in config_receiver.flows["default"]
assert len(start_flow_calls) == 2
assert all(c[0] == "default" for c in start_flow_calls)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_removed_flows(self):
"""Test fetch_and_apply stops removed flows"""
async def test_fetch_and_apply_workspace_stops_removed_flows(self):
"""fetch_and_apply_workspace stops flows no longer configured."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate with existing flows
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config now only has flow1
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}'
"flow1": '{"name": "test_flow_1"}',
}
}
@ -205,20 +185,22 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
stop_flow_calls = []
async def mock_stop_flow(id, flow):
stop_flow_calls.append((id, flow))
async def mock_stop_flow(workspace, id, flow):
stop_flow_calls.append((workspace, id, flow))
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" in config_receiver.flows
assert "flow2" not in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" not in config_receiver.flows["default"]
assert len(stop_flow_calls) == 1
assert stop_flow_calls[0][0] == "flow2"
assert stop_flow_calls[0][:2] == ("default", "flow2")
@pytest.mark.asyncio
async def test_fetch_and_apply_with_no_flows(self):
"""Test fetch_and_apply with empty config"""
async def test_fetch_and_apply_workspace_with_no_flows(self):
"""Empty workspace config clears any local flow state."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
@ -231,88 +213,100 @@ class TestConfigReceiver:
mock_client.request.return_value = mock_resp
config_receiver._create_config_client = Mock(return_value=mock_client)
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.flows == {}
assert config_receiver.flows.get("default", {}) == {}
assert config_receiver.config_version == 1
@pytest.mark.asyncio
async def test_start_flow_with_handlers(self):
"""Test start_flow method with multiple handlers"""
"""start_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.start_flow = Mock()
handler1.start_flow = AsyncMock()
handler2 = Mock()
handler2.start_flow = Mock()
handler2.start_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler1.start_flow.assert_called_once_with("flow1", flow_data)
handler2.start_flow.assert_called_once_with("flow1", flow_data)
handler1.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_start_flow_with_handler_exception(self):
"""Test start_flow method handles handler exceptions"""
"""Handler exceptions in start_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.start_flow = Mock(side_effect=Exception("Handler error"))
handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler.start_flow.assert_called_once_with("flow1", flow_data)
handler.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handlers(self):
"""Test stop_flow method with multiple handlers"""
"""stop_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.stop_flow = Mock()
handler1.stop_flow = AsyncMock()
handler2 = Mock()
handler2.stop_flow = Mock()
handler2.stop_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
handler1.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handler_exception(self):
"""Test stop_flow method handles handler exceptions"""
"""Handler exceptions in stop_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler.stop_flow.assert_called_once_with("flow1", flow_data)
handler.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@patch('asyncio.create_task')
@pytest.mark.asyncio
@ -329,25 +323,25 @@ class TestConfigReceiver:
mock_create_task.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_and_apply_mixed_flow_operations(self):
"""Test fetch_and_apply with mixed add/remove operations"""
async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
"""fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config removes flow1, keeps flow2, adds flow3
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow2": '{"name": "test_flow_2"}',
"flow3": '{"name": "test_flow_3"}'
"flow3": '{"name": "test_flow_3"}',
}
}
@ -358,20 +352,22 @@ class TestConfigReceiver:
start_calls = []
stop_calls = []
async def mock_start_flow(id, flow):
start_calls.append((id, flow))
async def mock_stop_flow(id, flow):
stop_calls.append((id, flow))
async def mock_start_flow(workspace, id, flow):
start_calls.append((workspace, id, flow))
async def mock_stop_flow(workspace, id, flow):
stop_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" not in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow3" in config_receiver.flows
ws_flows = config_receiver.flows["default"]
assert "flow1" not in ws_flows
assert "flow2" in ws_flows
assert "flow3" in ws_flows
assert len(start_calls) == 1
assert start_calls[0][0] == "flow3"
assert start_calls[0][:2] == ("default", "flow3")
assert len(stop_calls) == 1
assert stop_calls[0][0] == "flow1"
assert stop_calls[0][:2] == ("default", "flow1")

View file

@ -72,10 +72,10 @@ class TestDispatcherManager:
flow_data = {"name": "test_flow", "steps": []}
await manager.start_flow("flow1", flow_data)
assert "flow1" in manager.flows
assert manager.flows["flow1"] == flow_data
await manager.start_flow("default", "flow1", flow_data)
assert ("default", "flow1") in manager.flows
assert manager.flows[("default", "flow1")] == flow_data
@pytest.mark.asyncio
async def test_stop_flow(self):
@ -86,11 +86,11 @@ class TestDispatcherManager:
# Pre-populate with a flow
flow_data = {"name": "test_flow", "steps": []}
manager.flows["flow1"] = flow_data
await manager.stop_flow("flow1", flow_data)
assert "flow1" not in manager.flows
manager.flows[("default", "flow1")] = flow_data
await manager.stop_flow("default", "flow1", flow_data)
assert ("default", "flow1") not in manager.flows
def test_dispatch_global_service_returns_wrapper(self):
"""Test dispatch_global_service returns DispatcherWrapper"""
@ -275,12 +275,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -290,7 +290,7 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
params = {"flow": "test_flow", "kind": "triples"}
result = await manager.process_flow_import("ws", "running", params)
@ -326,12 +326,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
mock_dispatchers.__contains__.return_value = False
@ -348,12 +348,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -404,7 +404,7 @@ class TestDispatcherManager:
params = {"flow": "test_flow", "kind": "agent"}
result = await manager.process_flow_service("data", "responder", params)
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
assert result == "flow_result"
@pytest.mark.asyncio
@ -415,14 +415,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Add flow to the flows dictionary
manager.flows["test_flow"] = {"services": {"agent": {}}}
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
# Pre-populate with existing dispatcher
mock_dispatcher = Mock()
mock_dispatcher.process = AsyncMock(return_value="cached_result")
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
mock_dispatcher.process.assert_called_once_with("data", "responder")
assert result == "cached_result"
@ -435,7 +435,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -443,7 +443,7 @@ class TestDispatcherManager:
}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
@ -452,23 +452,23 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
request_queue="agent_request_queue",
response_queue="agent_response_queue",
timeout=120,
consumer="api-gateway-test_flow-agent-request",
subscriber="api-gateway-test_flow-agent-request"
consumer="api-gateway-default-test_flow-agent-request",
subscriber="api-gateway-default-test_flow-agent-request"
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
assert result == "new_result"
@pytest.mark.asyncio
@ -479,26 +479,26 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-load": {"flow": "text_load_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
mock_rr_dispatchers.__contains__.return_value = False
mock_sender_dispatchers.__contains__.return_value = True
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
mock_dispatcher.start = AsyncMock()
mock_dispatcher.process = AsyncMock(return_value="sender_result")
mock_dispatcher_class.return_value = mock_dispatcher
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
@ -506,9 +506,9 @@ class TestDispatcherManager:
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
assert result == "sender_result"
@pytest.mark.asyncio
@ -519,7 +519,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
with pytest.raises(RuntimeError, match="Invalid flow"):
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
@ -529,14 +529,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow without agent interface
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-completion": {"request": "req", "response": "resp"}
}
}
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_invalid_kind(self):
@ -546,7 +546,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow with interface but unsupported kind
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"invalid-kind": {"request": "req", "response": "resp"}
}
@ -558,7 +558,7 @@ class TestDispatcherManager:
mock_sender_dispatchers.__contains__.return_value = False
with pytest.raises(RuntimeError, match="Invalid kind"):
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
@pytest.mark.asyncio
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
@ -608,7 +608,7 @@ class TestDispatcherManager:
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -630,7 +630,7 @@ class TestDispatcherManager:
mock_rr_dispatchers.__contains__.return_value = True
results = await asyncio.gather(*[
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
for _ in range(5)
])
@ -638,5 +638,5 @@ class TestDispatcherManager:
"Dispatcher class instantiated more than once — duplicate consumer bug"
)
assert mock_dispatcher.start.call_count == 1
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
assert all(r == "result" for r in results)

View file

@ -239,7 +239,7 @@ def _make_processor(tools=None):
agent = MagicMock()
agent.tools = tools or {}
agent.additional_context = ""
processor.agent = agent
processor.agents = {"default": agent}
processor.aggregator = MagicMock()
return processor
@ -254,6 +254,7 @@ def _make_flow():
return producers[name]
flow = MagicMock(side_effect=factory)
flow.workspace = "default"
return flow
@ -299,7 +300,7 @@ class TestAgentReactDagStructure:
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
@ -433,7 +434,7 @@ class TestAgentPlanDagStructure:
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
@ -537,7 +538,7 @@ class TestAgentSupervisorDagStructure:
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)

View file

@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic:
"""Test parsing of schema configuration"""
processor = MagicMock()
processor.schemas = {}
processor.schema_builders = {}
processor.graphql_schemas = {}
processor.config_key = "schema"
processor.schema_builder = MagicMock()
processor.schema_builder.clear = MagicMock()
processor.schema_builder.add_schema = MagicMock()
processor.schema_builder.build = MagicMock(return_value=MagicMock())
processor.query_cassandra = MagicMock()
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic:
}
# Process config
await processor.on_schema_config(schema_config, version=1)
await processor.on_schema_config("default", schema_config, version=1)
# Verify schema was loaded
assert "customer" in processor.schemas
schema = processor.schemas["customer"]
assert "customer" in processor.schemas["default"]
schema = processor.schemas["default"]["customer"]
assert schema.name == "customer"
assert len(schema.fields) == 3
@ -147,24 +146,26 @@ class TestRowsGraphQLQueryLogic:
status_field = next(f for f in schema.fields if f.name == "status")
assert status_field.enum_values == ["active", "inactive"]
# Verify schema builder was called
processor.schema_builder.add_schema.assert_called_once()
processor.schema_builder.build.assert_called_once()
# Verify per-workspace schema builder was created and graphql schema built
assert "default" in processor.schema_builders
assert "default" in processor.graphql_schemas
@pytest.mark.asyncio
async def test_graphql_context_handling(self):
"""Test GraphQL execution context setup"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
graphql_schema = AsyncMock()
processor.graphql_schemas = {"default": graphql_schema}
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Mock schema execution
mock_result = MagicMock()
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
mock_result.errors = None
processor.graphql_schema.execute.return_value = mock_result
graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
workspace="default",
query='{ customers { id name } }',
variables={},
operation_name=None,
@ -173,8 +174,8 @@ class TestRowsGraphQLQueryLogic:
)
# Verify schema.execute was called with correct context
processor.graphql_schema.execute.assert_called_once()
call_args = processor.graphql_schema.execute.call_args
graphql_schema.execute.assert_called_once()
call_args = graphql_schema.execute.call_args
# Verify context was passed
context = call_args[1]['context_value']
@ -190,7 +191,8 @@ class TestRowsGraphQLQueryLogic:
async def test_error_handling_graphql_errors(self):
"""Test GraphQL error handling and conversion"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
graphql_schema = AsyncMock()
processor.graphql_schemas = {"default": graphql_schema}
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Create a simple object to simulate GraphQL error
@ -212,9 +214,10 @@ class TestRowsGraphQLQueryLogic:
mock_result = MagicMock()
mock_result.data = None
mock_result.errors = [mock_error]
processor.graphql_schema.execute.return_value = mock_result
graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
workspace="default",
query='{ customers { invalid_field } }',
variables={},
operation_name=None,
@ -259,6 +262,7 @@ class TestRowsGraphQLQueryLogic:
# Mock flow
mock_flow = MagicMock()
mock_flow.workspace = "default"
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
@ -267,6 +271,7 @@ class TestRowsGraphQLQueryLogic:
# Verify query was executed
processor.execute_graphql_query.assert_called_once_with(
workspace="default",
query='{ customers { id name } }',
variables={},
operation_name=None,

View file

@ -286,11 +286,11 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert
assert "test_schema" in processor.schemas
schema = processor.schemas["test_schema"]
assert "test_schema" in processor.schemas["default"]
schema = processor.schemas["default"]["test_schema"]
assert schema.name == "test_schema"
assert schema.description == "Test schema"
assert len(schema.fields) == 2
@ -308,10 +308,10 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert - bad schema should be ignored
assert "bad_schema" not in processor.schemas
assert "bad_schema" not in processor.schemas.get("default", {})
def test_processor_initialization(self, mock_pulsar_client):
"""Test processor initialization with correct specifications"""

View file

@ -101,7 +101,7 @@ def service(mock_schemas):
taskgroup=MagicMock(),
id="test-processor"
)
service.schemas = mock_schemas
service.schemas = {"default": dict(mock_schemas)}
return service
@ -109,6 +109,7 @@ def service(mock_schemas):
def mock_flow():
"""Create mock flow with prompt service"""
flow = MagicMock()
flow.workspace = "default"
prompt_request_flow = AsyncMock()
flow.return_value.request = prompt_request_flow
return flow, prompt_request_flow

View file

@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class _MockFlowDefault:
"""Mock Flow with default workspace for testing."""
workspace = "default"
name = "default"
id = "test-processor"
mock_flow_default = _MockFlowDefault()
class TestRowsCassandraStorageLogic:
"""Test business logic for unified table implementation"""
@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic:
}
# Process configuration
await processor.on_schema_config(config, version=1)
await processor.on_schema_config("default", config, version=1)
# Verify schema was loaded
assert "customer_records" in processor.schemas
schema = processor.schemas["customer_records"]
assert "customer_records" in processor.schemas["default"]
schema = processor.schemas["default"]["customer_records"]
assert schema.name == "customer_records"
assert len(schema.fields) == 3
@ -165,14 +176,16 @@ class TestRowsCassandraStorageLogic:
"""Test that row processing stores data as map<text, text>"""
processor = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="string", size=100)
]
)
"default": {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="string", size=100)
]
)
}
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
@ -205,7 +218,7 @@ class TestRowsCassandraStorageLogic:
msg.value.return_value = test_obj
# Process object
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify insert was executed
mock_async_execute.assert_called()
@ -230,14 +243,16 @@ class TestRowsCassandraStorageLogic:
"""Test that row is written once per indexed field"""
processor = MagicMock()
processor.schemas = {
"multi_index_schema": RowSchema(
name="multi_index_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="status", type="string", indexed=True)
]
)
"default": {
"multi_index_schema": RowSchema(
name="multi_index_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="status", type="string", indexed=True)
]
)
}
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
@ -267,7 +282,7 @@ class TestRowsCassandraStorageLogic:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 inserts (one per indexed field: id, category, status)
assert mock_async_execute.call_count == 3
@ -290,13 +305,15 @@ class TestRowsCassandraStorageBatchLogic:
"""Test processing of batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"batch_schema": RowSchema(
name="batch_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
"default": {
"batch_schema": RowSchema(
name="batch_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
}
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
@ -331,7 +348,7 @@ class TestRowsCassandraStorageBatchLogic:
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 inserts (one per row, one index per row since only primary key)
assert mock_async_execute.call_count == 3
@ -349,10 +366,12 @@ class TestRowsCassandraStorageBatchLogic:
"""Test processing of empty batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", primary=True)]
)
"default": {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", primary=True)]
)
}
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
@ -381,7 +400,7 @@ class TestRowsCassandraStorageBatchLogic:
msg = MagicMock()
msg.value.return_value = empty_batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify no insert calls for empty batch
processor.session.execute.assert_not_called()
@ -446,19 +465,21 @@ class TestPartitionRegistration:
processor.registered_partitions = set()
processor.session = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True)
]
)
"default": {
"test_schema": RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True)
]
)
}
}
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
# Should have 2 inserts (one per index: id, category)
assert processor.session.execute.call_count == 2
@ -473,7 +494,7 @@ class TestPartitionRegistration:
processor.session = MagicMock()
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
# Should not execute any CQL since already registered
processor.session.execute.assert_not_called()